msaeed3 committed
Commit e295beb · 1 Parent(s): 6f01ce4

version 1.0

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. arabic/decode_one_image.py +225 -0
  2. arabic/page_htr.py +283 -0
  3. arabic/post_process_routines.py +301 -0
  4. arabic/test_hw_helper_routines.py +147 -0
  5. arabic/warp_routines.py +368 -0
  6. coords/__init__.py +0 -0
  7. coords/points.py +342 -0
  8. coords/poly_routines.py +195 -0
  9. coords/text_cleaning_routines.py +82 -0
  10. coords/text_gt.py +101 -0
  11. model/trial_26_A/muharaf_charset.json +320 -0
  12. model/trial_26_A/set0/config_2600.yaml +25 -0
  13. model/trial_26_A/set0/pretrain/hw.pt +3 -0
  14. model/trial_26_A/set0/pretrain/lf.pt +3 -0
  15. model/trial_26_A/set0/pretrain/sol.pt +3 -0
  16. py3/e2e/__init__.py +0 -0
  17. py3/e2e/alignment_dataset.py +69 -0
  18. py3/e2e/e2e_model.py +207 -0
  19. py3/e2e/e2e_postprocessing.py +182 -0
  20. py3/e2e/forward_pass.py +86 -0
  21. py3/e2e/handwriting_alignment_loss.py +125 -0
  22. py3/e2e/nms.py +162 -0
  23. py3/e2e/validation_utils.py +137 -0
  24. py3/e2e/visualization.py +176 -0
  25. py3/hw/__init__.py +0 -0
  26. py3/hw/cnn_lstm.py +117 -0
  27. py3/lf/__init__.py +0 -0
  28. py3/lf/fast_patch_view.py +96 -0
  29. py3/lf/lf_cnn.py +45 -0
  30. py3/lf/line_follower.py +181 -0
  31. py3/lf/models/__init__.py +36 -0
  32. py3/lf/models/res_unet.py +147 -0
  33. py3/lf/models/resnet.py +335 -0
  34. py3/lf/models/tools.py +144 -0
  35. py3/lf/stn/__init__.py +0 -0
  36. py3/lf/stn/gridgen.py +126 -0
  37. py3/sol/__init__.py +0 -0
  38. py3/sol/crop_transform.py +35 -0
  39. py3/sol/crop_utils.py +48 -0
  40. py3/sol/start_of_line_finder.py +42 -0
  41. py3/sol/vgg.py +157 -0
  42. py3/utils/__init__.py +0 -0
  43. py3/utils/character_set.ipynb +539 -0
  44. py3/utils/character_set.py +61 -0
  45. py3/utils/continuous_state.py +87 -0
  46. py3/utils/dataset_parse.py +17 -0
  47. py3/utils/dataset_wrapper.py +27 -0
  48. py3/utils/error_rates.py +21 -0
  49. py3/utils/fast_inverse.py +58 -0
  50. py3/utils/safe_load.py +30 -0
arabic/decode_one_image.py ADDED
@@ -0,0 +1,225 @@
+ import os
+ import sys
+ import torch
+
+ from utils.continuous_state import init_model
+ from e2e import e2e_model, e2e_postprocessing, visualization
+ from e2e.e2e_model import E2EModel
+
+ from torch import nn
+ from torch.autograd import Variable
+
+ import json
+ import cv2
+ import numpy as np
+
+ import codecs
+ import yaml
+
+ from collections import defaultdict
+ import operator
+ import pandas as pd
+ from utils import error_rates
+ import matplotlib.pyplot as plt
+ import argparse
+
+
+ # Network output on one image.
+ # Reads the image from image_path if org_img is None.
+ def network_output(config_file, image_path, model_mode="best_overall",
+                    flip=False, use_unet=False, org_img=None, device="cuda"):
+
+     with open(config_file) as f:
+         config = yaml.load(f, Loader=yaml.Loader)
+
+     if use_unet:
+         config['network']['lf']['u_net'] = True
+
+     char_set_path = config['network']['hw']['char_set_path']
+     # Set hw's num_of_outputs in the config from the character set
+     with open(char_set_path) as f:
+         char_set = json.load(f)
+
+     config["network"]["hw"]["num_of_outputs"] = len(char_set['idx_to_char']) + 1
+
+     dtype = torch.FloatTensor
+     if 'cuda' in device:
+         dtype = torch.cuda.FloatTensor
+
+     sol, lf, hw = init_model(config, sol_dir=model_mode, lf_dir=model_mode, hw_dir=model_mode,
+                              device=device)
+
+     e2e = E2EModel(sol, lf, hw, dtype=dtype, device=device)
+     e2e.eval()
+
+     if org_img is None:
+         org_img = cv2.imread(image_path)
+         if flip:
+             org_img = cv2.flip(org_img, 1)
+
+     target_dim1 = 512
+     s = target_dim1 / float(org_img.shape[1])
+
+     pad_amount = 128
+     org_img = np.pad(org_img, ((pad_amount, pad_amount), (pad_amount, pad_amount), (0, 0)),
+                      'constant', constant_values=255)
+
+     target_dim0 = int(org_img.shape[0] * s)
+     target_dim1 = int(org_img.shape[1] * s)
+
+     full_img = org_img.astype(np.float32)
+     full_img = full_img.transpose([2, 1, 0])[None, ...]
+     full_img = torch.from_numpy(full_img)
+     full_img = full_img / 128 - 1
+
+     img = cv2.resize(org_img, (target_dim1, target_dim0), interpolation=cv2.INTER_CUBIC)
+     img = img.astype(np.float32)
+     img = img.transpose([2, 1, 0])[None, ...]
+     img = torch.from_numpy(img)
+     img = img / 128 - 1
+
+     out = e2e.forward({
+         "resized_img": img,
+         "full_img": full_img,
+         "resize_scale": 1.0 / s
+     }, use_full_img=True, device=device)
+
+     out = e2e_postprocessing.results_to_numpy(out)
+
+     if out is None:
+         print("No Results")
+         return None
+
+     # Take the padding into account
+     out['sol'][:, :2] = out['sol'][:, :2] - pad_amount
+     for l in out['lf']:
+         l[:, :2, :2] = l[:, :2, :2] - pad_amount
+
+     out['image_path'] = image_path
+
+     return out
+
+
+ def decode_one_img_with_info(config_path, out, visualize=False, flip=False, org_img=None, device="cuda"):
+
+     with open(config_path) as f:
+         config = yaml.load(f, Loader=yaml.Loader)
+
+     char_set_path = config['network']['hw']['char_set_path']
+
+     with open(char_set_path) as f:
+         char_set = json.load(f)
+
+     idx_to_char = {}
+     for k, v in char_set['idx_to_char'].items():
+         idx_to_char[int(k)] = v
+
+     out = dict(out)
+
+     image_path = str(out['image_path'])
+     if org_img is None:
+         org_img = cv2.imread(image_path)
+         if flip:
+             org_img = cv2.flip(org_img, 1)
+
+     # Postprocessing steps
+     out['idx'] = np.arange(out['sol'].shape[0])
+     out = e2e_postprocessing.trim_ends(out)
+     e2e_postprocessing.filter_on_pick(out, e2e_postprocessing.select_non_empty_string(out))
+     out = e2e_postprocessing.postprocess(out,
+                                          sol_threshold=config['post_processing']['sol_threshold'],
+                                          lf_nms_params={
+                                              "overlap_range": config['post_processing']['lf_nms_range'],
+                                              "overlap_threshold": config['post_processing']['lf_nms_threshold']
+                                          })
+     order = e2e_postprocessing.read_order(out)
+     e2e_postprocessing.filter_on_pick(out, order)
+
+     # Get the output strings
+     output_strings, decoded_raw_hw = e2e_postprocessing.decode_handwriting(out, idx_to_char)
+     return out, output_strings
+
+
+ def write_line_images(images, parent_img_fullpath, result_dir='Result', flip=True):
+     directory = os.path.dirname(parent_img_fullpath)
+     parent_basename = os.path.basename(parent_img_fullpath)
+     dir_basename = parent_basename[0:parent_basename.rfind('_')]
+     result_dir = os.path.join(directory, result_dir)
+     # result_dir already includes the parent directory
+     new_directory = os.path.join(result_dir, dir_basename)
+     if not os.path.exists(result_dir):
+         os.mkdir(result_dir)
+     if not os.path.exists(new_directory):
+         os.mkdir(new_directory)
+     # Get rid of the file extension
+     parent_basename = parent_basename[0:parent_basename.rfind('.')]
+     for ind, img in enumerate(images):
+         filename = os.path.join(new_directory,
+                                 parent_basename + '_' + str(ind) + '.png')
+         if flip:
+             img = cv2.flip(img, 1)
+         cv2.imwrite(filename, img)
+
+
+ # Write a one-time CSV header if needed
+ def write_empty(result_file):
+     result_df = pd.DataFrame(columns=['image_file', 'ground_truth', 'prediction',
+                                       'region_type',
+                                       # 'gt_poly',
+                                       'lf_points', 'beginning', 'ending', 'SOL'])
+     result_df.to_csv(result_file, index=False)
+
+
+ # TODO: Verify that (x,y) are the first two of SOL output
+ def add_offset_to_sol(sol, offset):
+     for ind in range(len(sol)):
+         sol[ind][0] += offset[0]
+         sol[ind][1] += offset[1]
+     return sol
+
+
+ def add_offset_to_lf(lf, offset):
+     for pt_ind in range(len(lf)):
+         for line_ind in range(len(lf[0])):
+             lf[pt_ind][line_ind][0][0] = lf[pt_ind][line_ind][0][0] + offset[0]
+             lf[pt_ind][line_ind][1][0] = lf[pt_ind][line_ind][1][0] + offset[1]
+             lf[pt_ind][line_ind][0][1] = lf[pt_ind][line_ind][0][1] + offset[0]
+             lf[pt_ind][line_ind][1][1] = lf[pt_ind][line_ind][1][1] + offset[1]
+     return lf
+
+
+ # The merged output ends up in out1
+ def merge_out(out1, out2, offset):
+     lf2 = add_offset_to_lf(out2['lf'], offset)
+     out1['lf'].extend(lf2)
+
+     out1['beginning'] = np.concatenate((out1['beginning'], out2['beginning']))
+     out1['ending'] = np.concatenate((out1['ending'], out2['ending']))
+
+     sol2 = add_offset_to_sol(out2['sol'], offset)
+     out1['sol'] = np.vstack((out1['sol'], sol2))
+     return out1
+
+
+ def split_image_horizontal(img_file):
+     f1 = img_file[:-4] + '_1' + '.jpg'
+     f2 = img_file[:-4] + '_2' + '.jpg'
+     img = cv2.imread(img_file)
+     [ht, width, colors] = img.shape
+     ht1 = int(ht / 2)
+     img1 = img[:ht1, :, :]
+     img2 = img[ht1:, :, :]
+     cv2.imwrite(f1, img1)
+     cv2.imwrite(f2, img2)
+     return f1, f2, ht1
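
A minimal usage sketch for this module (an illustration, not part of the commit): the config path matches model/trial_26_A/set0/config_2600.yaml from this commit, page.jpg is a placeholder image, and flip=True mirrors how page_htr.py calls these functions for right-to-left pages.

import decode_one_image as decode

config_file = "model/trial_26_A/set0/config_2600.yaml"
image_path = "page.jpg"  # placeholder

# Run the SOL + LF + HW pipeline on the page, then decode the line strings
out = decode.network_output(config_file, image_path, model_mode="pretrain", flip=True)
if out is not None:
    out, line_texts = decode.decode_one_img_with_info(config_file, out, flip=True)
    for ind, text in enumerate(line_texts):
        print(ind, text)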
arabic/page_htr.py ADDED
@@ -0,0 +1,283 @@
+ import sys
+ sys.path.append('py3/')
+ sys.path.append('coords')
+ import test_hw_helper_routines as test_hw
+ import json
+ import os
+ import torch
+ import matplotlib.pyplot as plt
+ import cv2
+ import pandas as pd
+ import numpy as np
+ import decode_one_image as decode
+ import post_process_routines as post
+ import points
+ import warp_routines as warp
+ from datetime import datetime, timezone
+ from utils import error_rates
+ import text_cleaning_routines as clean
+ import time
+ import argparse
+
+
+ def add_meta(json_obj):
+     # Get the current date and time in UTC
+     now_utc = datetime.now(timezone.utc)
+     # Format the date and time as a string
+     formatted_date_utc = now_utc.strftime('%Y-%m-%dT%H:%M:%S')
+
+     json_obj['timeStamps'] = {"created": formatted_date_utc,
+                               "lastEdited": "",
+                               "submitted": "",
+                               "checked": ""
+                               }
+
+     json_obj["annotators"] = {
+         "creator": "SFR",
+         "lastEditor": "",
+         "transcriber": "",
+         "transcription_QA": "",
+         "transcription_tagging": "",
+         "transcription_tagging_QA": ""}
+
+     return json_obj
+
+
+ def reset_time(json_obj):
+     json_obj["time"] = 0
+     for line in json_obj:
+         if line.startswith("line_"):
+             json_obj[line]["transcribeTime"] = 0
+             json_obj[line]["annotateTime"] = 0
+             json_obj[line]["edited"] = "0"
+     return json_obj
+
+
+ def get_hw(config_file, device="cuda"):
+     config = test_hw.get_config(config_file)
+
+     idx_to_char = test_hw.load_char_set(config['network']['hw']['char_set_path'])
+     if 'hw_to_save' in config['pretraining']:
+         pt_file = config['pretraining']['hw_to_save']
+     else:
+         pt_file = 'hw.pt'
+     pt_filename = os.path.join(config['pretraining']['snapshot_path'], pt_file)
+
+     config["network"]["hw"]["num_of_outputs"] = len(idx_to_char) + 1
+
+     print('...Using snapshot', pt_filename)
+     HW = test_hw.load_HW(config['network']['hw'], pt_filename)
+     device = torch.device(device)
+     HW.to(device)
+     HW.eval()
+     return HW, idx_to_char
+
+
+ def sort_lines(json_obj, copy_lines_with_text_only=False):
+     top_left = []
+     keys = []
+     for k, v in json_obj.items():
+         if k.startswith('line_'):
+             keys.append(k)
+             poly = np.array(points.list_to_xy(v['coord']))
+             top_left.append([np.max(poly, 0)[0], np.min(poly, 0)[1]])
+     # Sort by y, then by descending x (right-to-left reading order)
+     sorted_indices = sorted(range(len(top_left)),
+                             key=lambda i: (top_left[i][1], -top_left[i][0]))
+
+     sorted_json = dict()
+     # Copy all non-line keys
+     for k, v in json_obj.items():
+         if not k.startswith('line_'):
+             sorted_json[k] = v
+     # Copy all lines
+     for i, ind in enumerate(sorted_indices):
+         if copy_lines_with_text_only and len(json_obj[keys[ind]]['text']) == 0:
+             continue
+         sorted_json[f'line_{i + 1}'] = json_obj[keys[ind]]
+
+     return sorted_json
+
+
+ # This will not do line annotations...only transcriptions
+ def complete_annotations_for_directory(input_dir, config_file,
+                                        annotator, model_mode="pretrain",
+                                        do_all=False):
+     total_done = 0
+     files = os.listdir(input_dir)
+     files.sort()
+     HW, idx_to_char = get_hw(config_file)
+
+     for f in files:
+         if not f.lower().endswith('.jpg'):
+             continue
+         img_file = os.path.join(input_dir, f)
+         json_file = img_file[:-4] + '_annotate_' + annotator + '.json'
+         if not os.path.exists(json_file):
+             print('No Json for', img_file)
+             continue
+         print('doing', img_file)
+
+         with open(json_file) as fin:
+             json_obj = json.load(fin)
+         # Add meta information and reset the timings in json
+         #json_obj = add_meta(json_obj)
+         #json_obj = reset_time(json_obj)
+
+         # Read the page image once for all of its lines
+         img = cv2.imread(img_file)
+         for line, values in json_obj.items():
+             if not line.startswith('line_'):
+                 continue
+             if not do_all and len(values['text']) > 0:
+                 continue
+             line_img = warp.get_line_image(values['coord'], img)
+             line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
+                                                   img=line_img, read_image=False)
+
+             line_text_logical_order = clean.get_clean_visual_order(line_text)
+             json_obj[line]['text'] = line_text_logical_order
+
+         json_obj = sort_lines(json_obj)
+         with open(json_file, 'w') as fout:
+             json.dump(json_obj, fout, indent=2)
+
+         total_done += 1
+     return total_done
+
+
+ # This will do bulk annotation for the whole directory
+ def predict_annotations_for_directory(input_dir, config_file, annotator, model_mode="pretrain",
+                                       skip_if_json_exists=False, device="cuda"):
+     print('skip_if_json_exists', skip_if_json_exists)
+     files = os.listdir(input_dir)
+     files.sort()
+     done = 0
+     for f in files:
+         if not f.lower().endswith('.jpg'):
+             continue
+         img_file = os.path.join(input_dir, f)
+
+         json_file = img_file[:-4] + '_annotate_' + annotator + '.json'
+         if os.path.exists(json_file) and skip_if_json_exists:
+             print('already done', json_file)
+             continue
+         print('doing', img_file)
+         out = decode.network_output(config_file, img_file, flip=True, model_mode=model_mode, device=device)
+         out, predicted_text = decode.decode_one_img_with_info(config_file, out, flip=True, device=device)
+
+         poly_list = post.get_polygon_list_tuples(out)
+
+         # Get rid of degenerate polygons
+         to_del_ind = []
+         for ind, p in enumerate(poly_list):
+             if len(p) < 3:
+                 to_del_ind.append(ind)
+
+         if len(to_del_ind) > 0:
+             print('Deleting poly at index', to_del_ind)
+             poly_list = [poly_list[i] for i in range(len(poly_list)) if i not in to_del_ind]
+             predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in to_del_ind]
+
+         del_list, poly_list = post.get_poly_no_overlap(img_file, poly_list, 0.7)
+
+         if len(del_list) > 0:
+             print('polygons deleted', len(del_list), del_list)
+             print(len(poly_list))
+
+         predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in del_list]
+         poly_list = post.flip_polygon(img_file, poly_list)
+         page_json = post.create_annotations_json(predicted_text, poly_list)
+
+         with open(json_file, 'w') as fout:
+             json.dump(page_json, fout)
+         done += 1
+
+     return done
+
+
+ def page_htr_one_file(img_file, config_file, model_mode="pretrain", device="cuda"):
+     out = decode.network_output(config_file, img_file, flip=True, model_mode=model_mode, device=device)
+     out, predicted_text = decode.decode_one_img_with_info(config_file, out, flip=True, device=device)
+
+     poly_list = post.get_polygon_list_tuples(out)
+
+     # Get rid of degenerate polygons
+     to_del_ind = []
+     for ind, p in enumerate(poly_list):
+         if len(p) < 3:
+             to_del_ind.append(ind)
+
+     if len(to_del_ind) > 0:
+         poly_list = [poly_list[i] for i in range(len(poly_list)) if i not in to_del_ind]
+         predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in to_del_ind]
+
+     del_list, poly_list = post.get_poly_no_overlap(img_file, poly_list, 0.7)
+     predicted_text = [predicted_text[i] for i in range(len(predicted_text)) if i not in del_list]
+     predicted_text = [clean.get_clean_visual_order(txt) for txt in predicted_text]
+     poly_list = post.flip_polygon(img_file, poly_list)
+     page_json = post.create_annotations_json(predicted_text, poly_list)
+
+     torch.cuda.empty_cache()
+     return page_json
+
+
+ def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None):
+     HW, idx_to_char = get_hw(config_file)
+
+     # Read the page image once for all of its lines
+     img = cv2.imread(img_file)
+     for line, values in json_obj.items():
+         if not line.startswith('line_'):
+             continue
+         # If line_key is specified then modify only that line
+         if line_key is not None and len(line_key) > 0 and line != line_key:
+             continue
+
+         line_img = warp.get_line_image(values['coord'], img)
+         line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
+                                               img=line_img, read_image=False)
+         line_text_logical_order = clean.get_clean_visual_order(line_text)
+         json_obj[line]['text'] = line_text_logical_order
+
+     json_obj = sort_lines(json_obj)
+     torch.cuda.empty_cache()
+     return json_obj
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Run HTR module")
+
+     parser.add_argument("--line_htr", type=int, required=True, help="If 1, do line_htr; else do page_htr")
+     parser.add_argument("--img_path", type=str, required=True, help="Image path")
+     parser.add_argument("--config_file", type=str, required=True, help="SFR_Arabic config file")
+     parser.add_argument("--original_json", type=str, required=True, help="Original JSON")
+     parser.add_argument("--line_key", type=str, required=True, help="Line key")
+
+     args = parser.parse_args()
+     json_obj = {}
+     if args.line_htr == 1:
+         json_obj = json.loads(args.original_json)
+         json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
+                                model_mode="pretrain", line_key=args.line_key)
+     else:
+         json_obj = page_htr_one_file(args.img_path, args.config_file, device="cuda")
+
+     print('BEGIN_OUT')
+     print(json.dumps(json_obj))
+
+ # python3 arabic/page_htr.py --line_htr 0 --img_path ../../datasets/kclds/KEllis/bk_on_server/bk/KEllis2018-150a.jpg --config_file model/trial_26_A/set0/config_2600.yaml --original_json {} --line_key 0
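
The commented invocation above prints a BEGIN_OUT marker followed by the page JSON on stdout. A caller can recover the JSON from that output; a sketch, assuming the script is run from the repo root and page.jpg is a placeholder image:

import json
import subprocess

cmd = ["python3", "arabic/page_htr.py", "--line_htr", "0",
       "--img_path", "page.jpg",  # placeholder
       "--config_file", "model/trial_26_A/set0/config_2600.yaml",
       "--original_json", "{}", "--line_key", "0"]
stdout = subprocess.run(cmd, capture_output=True, text=True, check=True).stdout
# Everything after the BEGIN_OUT marker is the serialized page JSON
page_json = json.loads(stdout.split("BEGIN_OUT\n", 1)[1])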
arabic/post_process_routines.py ADDED
@@ -0,0 +1,301 @@
+ import numpy as np
+ import cv2
+ import json
+ import sys
+ sys.path.append('coords/')
+ import points
+ from PIL import Image, ImageDraw
+ import os
+ import matplotlib.pyplot as plt
+
+ from utils import error_rates
+ from arabic_reshaper import ArabicReshaper
+
+ # This file has routines for working with CERs and polygons
+
+
+ def correct_pt(value, max_value):
+     boundary = False
+     if value < 0:
+         value = 0
+         boundary = True
+     if value >= max_value:
+         value = max_value - 1
+         boundary = True
+     return [value, boundary]
+
+
+ # Each polygon is a list of (x,y) tuples
+ def get_polygon_list_tuples(out):
+     img = cv2.imread(out["image_path"])
+     img_height, img_width = img.shape[:2]
+     polygon_list = []
+     for line_ind in range(len(out['lf'][0])):
+         polygon = []
+         prev = [-1, -1]
+         begin_ind = out['beginning'][line_ind]
+         end_ind = out['ending'][line_ind]
+         begin_ind = int(np.floor(begin_ind))
+         end_ind = int(np.ceil(end_ind))
+         end_ind = min(end_ind, len(out['lf']) - 1)
+         # Top edge forward, then bottom edge in reverse, to close the polygon
+         for pt_ind in range(begin_ind, end_ind + 1):
+             pt_x = float(out['lf'][pt_ind][line_ind][0][0])
+             pt_y = float(out['lf'][pt_ind][line_ind][1][0])
+             pt_x, boundary_x = correct_pt(pt_x, img_width)
+             pt_y, boundary_y = correct_pt(pt_y, img_height)
+             if prev != [pt_x, pt_y]:
+                 polygon.append((pt_x, pt_y))
+             prev = [pt_x, pt_y]
+         for pt_ind in range(end_ind, begin_ind - 1, -1):
+             pt_x = float(out['lf'][pt_ind][line_ind][0][1])
+             pt_y = float(out['lf'][pt_ind][line_ind][1][1])
+             pt_x, boundary_x = correct_pt(pt_x, img_width)
+             pt_y, boundary_y = correct_pt(pt_y, img_height)
+             if prev != [pt_x, pt_y]:
+                 polygon.append((pt_x, pt_y))
+             prev = [pt_x, pt_y]
+
+         polygon_list.append(polygon)
+         if len(polygon) < 3:
+             print('WARNING: DEGENERATE POLYGON AT INDEX', len(polygon_list))
+     return polygon_list
+
+
+ # Each polygon is a list of (x,y) tuples
+ def get_polygon_list_without_trim(out):
+     img = cv2.imread(out["image_path"])
+     img_height, img_width = img.shape[:2]
+     polygon_list = []
+     for line_ind in range(len(out['lf'][0])):
+         polygon = []
+         begin_ind = 0
+         end_ind = len(out['lf']) - 1
+         prev = [-1, -1]
+
+         for pt_ind in range(begin_ind, end_ind + 1):
+             pt_x = float(out['lf'][pt_ind][line_ind][0][0])
+             pt_y = float(out['lf'][pt_ind][line_ind][1][0])
+             pt_x, boundary_x = correct_pt(pt_x, img_width)
+             pt_y, boundary_y = correct_pt(pt_y, img_height)
+             if prev != [pt_x, pt_y]:
+                 polygon.append((pt_x, pt_y))
+             prev = [pt_x, pt_y]
+         for pt_ind in range(end_ind, begin_ind - 1, -1):
+             pt_x = float(out['lf'][pt_ind][line_ind][0][1])
+             pt_y = float(out['lf'][pt_ind][line_ind][1][1])
+             pt_x, boundary_x = correct_pt(pt_x, img_width)
+             pt_y, boundary_y = correct_pt(pt_y, img_height)
+             if prev != [pt_x, pt_y]:
+                 polygon.append((pt_x, pt_y))
+             prev = [pt_x, pt_y]
+
+         if len(polygon) >= 3:
+             polygon_list.append(polygon)
+     return polygon_list
+
+
+ # Each polygon passed as input is a list of (x,y) tuples; same for the output.
+ # Returns the intersection area in pixels and the fraction of each polygon it covers.
+ def percent_intersection(size, poly1, poly2):
+     im1 = Image.new(mode="1", size=size)
+     draw1 = ImageDraw.Draw(im1)
+     draw1.polygon(poly1, fill=1)
+     im2 = Image.new(mode="1", size=size)
+     draw2 = ImageDraw.Draw(im2)
+     draw2.polygon(poly2, fill=1)
+     mask1 = np.asarray(im1, dtype=bool)
+     mask2 = np.asarray(im2, dtype=bool)
+     intersection_mask = mask1 & mask2
+     intersection_area = intersection_mask.sum()
+     percent1 = intersection_area / mask1.sum()
+     percent2 = intersection_area / mask2.sum()
+     return intersection_area, percent1, percent2
+
+
+ def get_poly_no_overlap(img_name, poly_list, threshold=0.6):
+     img = Image.open(img_name)
+     size = img.size
+     polygons = poly_list
+     del_list = []
+     current = 0
+     next_ind = current + 1
+     last_deleted = -1
+     while next_ind < len(polygons):
+         # Check these are not degenerate polygons
+         if len(polygons[current]) < 3:
+             del_list.append(current)
+             current, next_ind = (current + 1, next_ind + 1)
+             continue
+         if len(polygons[next_ind]) < 3:
+             del_list.append(next_ind)
+             next_ind += 1
+             continue
+         # End check
+         overlap_area, percent1, percent2 = percent_intersection(size,
+                                                                 polygons[current],
+                                                                 polygons[next_ind])
+
+         if percent1 > threshold or percent2 > threshold:
+             # Delete the polygon that is mostly covered by the other one
+             to_del = current if percent1 > percent2 else next_ind
+             current, next_ind = (current, next_ind + 1) if percent1 < percent2 \
+                 else (next_ind, next_ind + 1)
+             del_list.append(to_del)
+             last_deleted = to_del
+         else:  # when no overlap is found
+             current, next_ind = (current + 1, next_ind + 1)
+             if current <= last_deleted:
+                 current = last_deleted + 1
+                 next_ind = current + 1
+     all_ind = set(range(len(poly_list)))
+     good_ind = all_ind.difference(set(del_list))
+     poly_non_overlapping = [poly_list[i] for i in good_ind]
+     return del_list, poly_non_overlapping
+
+
+ def dump_polygons_json(out, polygons=None, filename=None):
+     if filename is None:
+         filename = out["image_path"][:-3] + "json"
+     if polygons is None:
+         polygons = get_polygon_list_tuples(out)
+     lf_dict = {}
+     for ind, poly in enumerate(polygons):
+         lf_dict['line_' + str(ind + 1)] = points.xy_to_list(poly)
+
+     with open(filename, 'w') as fout:
+         json_dumps_str = json.dumps(lf_dict, indent=2)
+         print(json_dumps_str, file=fout)
+
+
+ def write_json_file(out, poly_list=None, json_file=None):
+     if poly_list is None:
+         poly_list = get_polygon_list_tuples(out)
+     dump_polygons_json(out, poly_list, json_file)
+
+
+ def write_text_file(out, predicted_text, filename=None):
+     if filename is None:
+         filename = out["image_path"][:-3] + "txt"
+     prediction_para = '\n'.join(predicted_text)
+     with open(filename, 'w') as f:
+         f.write(prediction_para)
+
+
+ # Flips only the image, not the polygons
+ def draw_image_with_poly(directory, image, poly, convert=True, flip=False):
+     img = cv2.imread(os.path.join(directory, image))
+     if flip:
+         img = cv2.flip(img, 1)
+     plt.imshow(img)
+     colors = ['red', 'green', 'blue']
+
+     for ind, p in enumerate(poly):
+         if convert:
+             p = points.list_to_xy(p)
+         points.draw_poly(plt, p, colors[ind % 3])
+         plt.text(p[-1][0], p[-1][1], str(ind))
+
+
+ arabic_reshaper_configuration = {
+     'delete_harakat': True,
+     'support_ligatures': True
+ }
+
+
+ def remove_diacritics(txt):
+     reshaper = ArabicReshaper(configuration=arabic_reshaper_configuration)
+     txt_without_diacritics = reshaper.reshape(txt)
+     return txt_without_diacritics
+
+
+ # CER between every ground-truth line and every predicted line, clipped at 1
+ def get_cer_matrix(gt_list, prediction_list):
+     total_gt = len(gt_list)
+     total_predictions = len(prediction_list)
+     cer_matrix = np.zeros((total_gt, total_predictions))
+
+     for ind_g, g in enumerate(gt_list):
+         for ind_p, p in enumerate(prediction_list):
+             cer_matrix[ind_g, ind_p] = error_rates.cer(g, p)
+             if cer_matrix[ind_g, ind_p] > 1:
+                 cer_matrix[ind_g, ind_p] = 1
+
+     return cer_matrix
+
+
+ # https://en.wikipedia.org/wiki/Dynamic_time_warping
+ def get_dtw(dist_matrix):
+     r, c = dist_matrix.shape
+     dtw = np.full((r + 1, c + 1), np.inf)
+     dtw[0, 0] = 0
+
+     for i in range(1, dtw.shape[0]):
+         for j in range(1, dtw.shape[1]):
+             dtw[i, j] = dist_matrix[i - 1, j - 1] + np.min([dtw[i - 1, j], dtw[i, j - 1],
+                                                             dtw[i - 1, j - 1]])
+     return dtw[-1, -1]
+
+
+ def get_cers(gt_list, prediction_list, no_diacritics=False):
+     if no_diacritics:
+         gt_list = [remove_diacritics(g) for g in gt_list]
+         prediction_list = [remove_diacritics(p) for p in prediction_list]
+
+     cer_matrix = get_cer_matrix(gt_list, prediction_list)
+     cer_gt = cer_matrix.min(axis=1).flatten()
+     cer_p = cer_matrix.min(axis=0).flatten()
+     cer_dtw = get_dtw(cer_matrix)
+
+     gt = '\n'.join(gt_list)
+     pred = '\n'.join(prediction_list)
+     cer_para = error_rates.cer(gt, pred)
+
+     return cer_dtw, np.sum(cer_gt), np.sum(cer_p), cer_para
+
+
+ # Same as get_cers, but also returns the paragraph-level WER
+ def get_cers_wer(gt_list, prediction_list, no_diacritics=False):
+     if no_diacritics:
+         gt_list = [remove_diacritics(g) for g in gt_list]
+         prediction_list = [remove_diacritics(p) for p in prediction_list]
+
+     cer_matrix = get_cer_matrix(gt_list, prediction_list)
+     cer_gt = cer_matrix.min(axis=1).flatten()
+     cer_p = cer_matrix.min(axis=0).flatten()
+     cer_dtw = get_dtw(cer_matrix)
+
+     gt = '\n'.join(gt_list)
+     pred = '\n'.join(prediction_list)
+     cer_para = error_rates.cer(gt, pred)
+     wer_para = error_rates.wer(gt, pred)
+
+     return cer_dtw, np.sum(cer_gt), np.sum(cer_p), cer_para, wer_para
+
+
+ def flip_polygon(img_file, poly_list):
+     img = cv2.imread(img_file)
+     h, w = img.shape[:2]
+     flipped_poly_list = []
+     for p in poly_list:
+         flipped = [(w - x, y) for (x, y) in p]
+         flipped_poly_list.append(flipped)
+     return flipped_poly_list
+
+
+ # This will create annotations that can be used in the scribeArabic annotation tool
+ def create_annotations_json(predicted_text, poly_list):
+     if len(predicted_text) != len(poly_list):
+         raise ValueError("polygon list length does not match predicted text length")
+     page_json = dict()
+     for ind, (ocr, poly) in enumerate(zip(predicted_text, poly_list)):
+         pts = points.xy_to_list(poly)
+         line_key = f'line_{ind + 1}'
+         page_json[line_key] = dict()
+         page_json[line_key]['coord'] = pts
+         page_json[line_key]['text'] = ocr
+     return page_json
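
A worked example of get_dtw on a hand-built 2x2 CER matrix (values invented for illustration): the cheapest warping path pairs ground-truth line 0 with prediction 0 and line 1 with prediction 1, so the function returns 0.1 + 0.2 = 0.3.

import numpy as np
import post_process_routines as post

# Rows = ground-truth lines, columns = predicted lines
dist = np.array([[0.1, 0.9],
                 [0.8, 0.2]])
assert abs(post.get_dtw(dist) - 0.3) < 1e-9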
arabic/test_hw_helper_routines.py ADDED
@@ -0,0 +1,147 @@
+ import sys
+ sys.path.append('py3/')
+ import os
+ import json
+ import yaml
+ import pandas as pd
+ import numpy as np
+ import torch
+ import cv2
+ import matplotlib.pyplot as plt
+ import hw
+ from hw import cnn_lstm
+ from utils import string_utils, error_rates
+
+ HT = 60  # target line height in pixels
+
+
+ def load_char_set(char_set_path):
+     with open(char_set_path) as f:
+         char_set = json.load(f)
+
+     idx_to_char = {}
+     for k, v in char_set['idx_to_char'].items():
+         idx_to_char[int(k)] = v
+     return idx_to_char
+
+
+ def get_config(config_file):
+     with open(config_file) as f:
+         config = yaml.load(f, Loader=yaml.loader.SafeLoader)
+     return config
+
+
+ def load_HW(hw_network_config, pt_filename):
+     HW = cnn_lstm.create_model(hw_network_config)
+     hw_state = torch.load(pt_filename)
+     HW.load_state_dict(hw_state)
+
+     device = torch.device("cuda")
+     HW.to(device)
+     return HW
+
+
+ def get_predicted_str(HW, img_file, idx_to_char, device="cuda", flip=False,
+                       show=False, read_image=True, img=None, tokenizer=None):
+
+     device = torch.device(device)
+     if read_image:
+         img = cv2.imread(img_file)
+     ht, width = img.shape[:2]
+
+     if ht != HT:
+         new_width = int(width / ht * HT)
+         img = cv2.resize(img, (new_width, HT))
+     if show:
+         plt.imshow(img)
+         plt.show()
+     if flip:
+         img = np.flip(img, axis=1)
+     img = img.astype(np.float32)
+     img = img / 128.0 - 1.0
+     img = np.expand_dims(img, 0)
+     img = img.transpose([0, 3, 1, 2])
+     img = torch.from_numpy(img)
+
+     IMG = img.to(device)
+
+     preds = HW(IMG).cpu()
+
+     out = preds.permute(1, 0, 2)
+     out = out.data.numpy()
+     logits = out[0, ...]
+     pred, raw_pred = string_utils.naive_decode(logits)
+     if tokenizer is None:
+         pred_str = string_utils.label2str_single(pred, idx_to_char, False)
+     else:
+         pred_str = tokenizer.decode(pred)
+
+     del IMG
+     return pred_str
+
+
+ def write_csv_all_predictions(config_file, suffix="", device="cuda", flip=False, pt_file='hw.pt',
+                               test_file_to_use="", result_file="", tokenizer=None):
+
+     result_df = pd.DataFrame(columns=["image", "ground_truth", "prediction", "CER", "WER"])
+     config = get_config(config_file)
+     idx_to_char = load_char_set(config['network']['hw']['char_set_path'])
+     if 'hw_to_save' in config['pretraining']:
+         pt_file = config['pretraining']['hw_to_save']
+     pt_filename = os.path.join(config['pretraining']['snapshot_path'], pt_file)
+
+     config["network"]["hw"]["num_of_outputs"] = len(idx_to_char) + 1
+     if tokenizer is not None:
+         config["network"]["hw"]["num_of_outputs"] = tokenizer.get_vocab_size()
+         pt_filename = os.path.join(config['pretraining']['snapshot_path'], f"hw_tokenizer_{tokenizer.get_vocab_size()}.pt")
+
+     if len(suffix) > 0:
+         pt_filename = pt_filename[:-3] + suffix + '.pt'
+     print('...Using snapshot', pt_filename)
+     HW = load_HW(config['network']['hw'], pt_filename)
+     device = torch.device(device)
+     HW.to(device)
+     HW.eval()
+
+     if test_file_to_use == "":
+         test_json_file = config['testing']['test_file']
+     else:
+         test_json_file = test_file_to_use
+     print('Using test file', test_json_file)
+
+     with open(test_json_file) as f:
+         json_obj = json.load(f)
+     for obj in json_obj:
+         # obj is a list of: [json_file, img_file]
+         # Open the json file and get a list of records
+         with open(obj[0]) as f:
+             image_list = json.load(f)
+
+         for record in image_list:
+             # Skip records with no ground truth
+             if record['gt'] == 'None' or isinstance(record['gt'], float) or record['gt'] == 'nan' or len(record['gt']) == 0:
+                 print(type(record['gt']), record['gt'], record['hw_path'])
+                 continue
+             predicted_str = get_predicted_str(HW, record['hw_path'], idx_to_char, device=device,
+                                               flip=flip, tokenizer=tokenizer)
+
+             cer = error_rates.cer(record['gt'], predicted_str)
+             wer = error_rates.wer(record['gt'], predicted_str)
+             result_df.loc[len(result_df)] = [record['hw_path'], record['gt'],
+                                              predicted_str, cer, wer]
+
+     if len(result_file) == 0:
+         result_file = config_file.replace("config", "result")
+         result_file = result_file.replace("yaml", "csv")
+         if len(suffix) > 0:
+             result_file = result_file[:-4] + suffix + '.csv'
+     result_df.to_csv(result_file, index=False)
+     return result_df
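
A sketch of recognizing a single pre-cropped line image with this module; the checkpoint and config paths come from this commit, line.png is a placeholder, and flip=True mirrors how page_htr.py calls get_predicted_str:

import test_hw_helper_routines as test_hw

config = test_hw.get_config("model/trial_26_A/set0/config_2600.yaml")
idx_to_char = test_hw.load_char_set(config['network']['hw']['char_set_path'])
config["network"]["hw"]["num_of_outputs"] = len(idx_to_char) + 1
# load_HW also moves the model to the CUDA device
HW = test_hw.load_HW(config['network']['hw'], "model/trial_26_A/set0/pretrain/hw.pt")
HW.eval()
text = test_hw.get_predicted_str(HW, "line.png", idx_to_char, flip=True)  # placeholder file
print(text)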
arabic/warp_routines.py ADDED
@@ -0,0 +1,368 @@
+ from svgpathtools import Path, Line
+ from scipy.interpolate import griddata
+ import numpy as np
+ import cv2
+ import sys
+ sys.path.append('../../coords/')
+ import points
+ import torch
+
+
+ def generate_offset_mapping(img, ts, path, offset_1, offset_2, max_min=None, cube_size=None):
+     offset_1_pts = []
+     offset_2_pts = []
+     for i in range(len(ts)):
+         t = ts[i]
+         pt = path.point(t)
+
+         # Unit normal at pt, averaging the two segment normals at interior points
+         norm = None
+         if i == 0:
+             norm = normal(pt, path.point(ts[i + 1]))
+             norm = norm / dis(complex(0, 0), norm)
+         elif i == len(ts) - 1:
+             norm = normal(path.point(ts[i - 1]), pt)
+             norm = norm / dis(complex(0, 0), norm)
+         else:
+             norm1 = normal(path.point(ts[i - 1]), pt)
+             norm1 = norm1 / dis(complex(0, 0), norm1)
+             norm2 = normal(pt, path.point(ts[i + 1]))
+             norm2 = norm2 / dis(complex(0, 0), norm2)
+
+             norm = (norm1 + norm2) / 2
+             norm = norm / dis(complex(0, 0), norm)
+
+         offset_vector1 = offset_1 * norm
+         offset_vector2 = offset_2 * norm
+
+         pt1 = pt + offset_vector1
+         pt2 = pt + offset_vector2
+
+         offset_1_pts.append(complexToNpPt(pt1))
+         offset_2_pts.append(complexToNpPt(pt2))
+
+     offset_1_pts = np.array(offset_1_pts)
+     offset_2_pts = np.array(offset_2_pts)
+
+     h, w = img.shape[:2]
+
+     offset_source2 = np.array([(cube_size * i, 0) for i in range(len(offset_1_pts))], dtype=np.float32)
+     offset_source1 = np.array([(cube_size * i, cube_size) for i in range(len(offset_2_pts))], dtype=np.float32)
+
+     offset_source1 = offset_source1[::-1]
+     offset_source2 = offset_source2[::-1]
+
+     source = np.concatenate([offset_source1, offset_source2])
+     destination = np.concatenate([offset_1_pts, offset_2_pts])
+
+     source = source[:, ::-1]
+     destination = destination[:, ::-1]
+
+     n_w = int(offset_source2[:, 0].max())
+     n_h = int(cube_size)
+
+     grid_x, grid_y = np.mgrid[0:n_h, 0:n_w]
+
+     grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
+     map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(n_h, n_w)
+     map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(n_h, n_w)
+     rectified_to_warped_x = map_x.astype('float32')
+     rectified_to_warped_y = map_y.astype('float32')
+
+     grid_x, grid_y = np.mgrid[0:h, 0:w]
+     grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
+     map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(h, w)
+     map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(h, w)
+     warped_to_rectified_x = map_x.astype('float32')
+     warped_to_rectified_y = map_y.astype('float32')
+
+     return rectified_to_warped_x, rectified_to_warped_y, warped_to_rectified_x, warped_to_rectified_y, max_min
+
+
+ def dis(pt1, pt2):
+     a = (pt1.real - pt2.real)**2
+     b = (pt1.imag - pt2.imag)**2
+     return np.sqrt(a + b)
+
+
+ def complexToNpPt(pt):
+     return np.array([pt.real, pt.imag], dtype=np.float32)
+
+
+ def normal(pt1, pt2):
+     dif = pt1 - pt2
+     return complex(-dif.imag, dif.real)
+
+
+ # Find parameter values t along the path that are spaced cube_size apart
+ def find_t_spacing(path, cube_size):
+     l = path.length()
+     error = 0.01
+     init_step_size = cube_size / l
+
+     last_t = 0
+     cur_t = 0
+     pts = []
+     ts = [0]
+     pts.append(complexToNpPt(path.point(cur_t)))
+     for target in np.arange(cube_size, int(l), cube_size):
+         step_size = init_step_size
+         for i in range(1000):
+             cur_length = dis(path.point(last_t), path.point(cur_t))
+             if np.abs(cur_length - cube_size) < error:
+                 break
+
+             step_t = min(cur_t + step_size, 1.0)
+             step_l = dis(path.point(last_t), path.point(step_t))
+
+             if np.abs(step_l - cube_size) < np.abs(cur_length - cube_size):
+                 cur_t = step_t
+                 continue
+
+             step_t = max(cur_t - step_size, 0.0)
+             step_t = max(step_t, last_t)
+             step_t = min(step_t, 1.0)  # clamp t into [0, 1]
+
+             step_l = dis(path.point(last_t), path.point(step_t))
+
+             if np.abs(step_l - cube_size) < np.abs(cur_length - cube_size):
+                 cur_t = step_t
+                 continue
+
+             step_size = step_size / 2.0
+
+         last_t = cur_t
+
+         ts.append(cur_t)
+         pts.append(complexToNpPt(path.point(cur_t)))
+
+     pts = np.array(pts)
+
+     return ts
+
+
+ def remap_with_grid_sample(input_image, map_x, map_y, padding, img_tensor=None, device="cuda"):
+     reshaped = False
+     H, W = input_image.shape[:2]
+
+     if len(input_image.shape) == 2:
+         input_image = input_image.reshape(H, W, 1)
+         reshaped = True
+
+     if img_tensor is None:
+         # Convert input image to a PyTorch tensor in NCHW format, normalized to [0, 1]
+         img_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float() / 255.0
+
+     img_tensor = img_tensor.to(device)
+
+     # Convert map_x and map_y to normalized coordinates in the range [-1, 1]
+     norm_map_x = (torch.from_numpy(map_x.copy()).float() / (W - 1)) * 2 - 1
+     norm_map_y = (torch.from_numpy(map_y.copy()).float() / (H - 1)) * 2 - 1
+
+     # Stack normalized coordinates to create a grid of shape (1, H, W, 2)
+     grid = torch.stack((norm_map_x, norm_map_y), dim=-1).unsqueeze(0)
+
+     # Ensure the grid is on the same device as the input tensor
+     grid = grid.to(img_tensor.device)
+
+     # Apply grid_sample to perform the remap operation
+     output_tensor = torch.nn.functional.grid_sample(img_tensor, grid, mode='bilinear',
+                                                     padding_mode=padding, align_corners=True)
+
+     # Convert back to NumPy and scale back to [0, 255]
+     output_image = (output_tensor.squeeze(dim=0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
+     if reshaped:
+         output_image = output_image[:, :, 0]
+     return output_image
+
+
+ # Same as generate_offset_mapping, but only returns the rectified-to-warped maps
+ def generate_offset_mapping_1(img, ts, path, offset_1, offset_2, cube_size=None):
+     offset_1_pts = []
+     offset_2_pts = []
+     for i in range(len(ts)):
+         t = ts[i]
+         pt = path.point(t)
+
+         norm = None
+         if i == 0:
+             norm = normal(pt, path.point(ts[i + 1]))
+             norm = norm / dis(complex(0, 0), norm)
+         elif i == len(ts) - 1:
+             norm = normal(path.point(ts[i - 1]), pt)
+             norm = norm / dis(complex(0, 0), norm)
+         else:
+             norm1 = normal(path.point(ts[i - 1]), pt)
+             norm1 = norm1 / dis(complex(0, 0), norm1)
+             norm2 = normal(pt, path.point(ts[i + 1]))
+             norm2 = norm2 / dis(complex(0, 0), norm2)
+
+             norm = (norm1 + norm2) / 2
+             norm = norm / dis(complex(0, 0), norm)
+
+         offset_vector1 = offset_1 * norm
+         offset_vector2 = offset_2 * norm
+
+         pt1 = pt + offset_vector1
+         pt2 = pt + offset_vector2
+
+         offset_1_pts.append(complexToNpPt(pt1))
+         offset_2_pts.append(complexToNpPt(pt2))
+
+     offset_1_pts = np.array(offset_1_pts)
+     offset_2_pts = np.array(offset_2_pts)
+
+     h, w = img.shape[:2]
+
+     offset_source2 = np.array([(cube_size * i, 0) for i in range(len(offset_1_pts))], dtype=np.float32)
+     offset_source1 = np.array([(cube_size * i, cube_size) for i in range(len(offset_2_pts))], dtype=np.float32)
+
+     offset_source1 = offset_source1[::-1]
+     offset_source2 = offset_source2[::-1]
+
+     source = np.concatenate([offset_source1, offset_source2])
+     destination = np.concatenate([offset_1_pts, offset_2_pts])
+
+     source = source[:, ::-1]
+     destination = destination[:, ::-1]
+
+     n_w = int(offset_source2[:, 0].max())
+     n_h = int(cube_size)
+
+     grid_x, grid_y = np.mgrid[0:n_h, 0:n_w]
+
+     grid_z = griddata(source, destination, (grid_x, grid_y), method='cubic')
+     map_x = np.append([], [ar[:, 1] for ar in grid_z]).reshape(n_h, n_w)
+     map_y = np.append([], [ar[:, 0] for ar in grid_z]).reshape(n_h, n_w)
+     rectified_to_warped_x = map_x.astype('float32')
+     rectified_to_warped_y = map_y.astype('float32')
+
+     return rectified_to_warped_x, rectified_to_warped_y
+
+
+ def get_warped_images(img, polygon_list, baseline_list, target_height=60, device="cuda"):
+     warped_list = []
+
+     for ind in range(len(polygon_list)):
+         line_mask = extract_region_mask(img, polygon_list[ind])
+
+         # Estimate the line height from the median column height of the mask
+         summed_axis0 = (line_mask.astype(float) / 255).sum(axis=0)
+         avg_height0 = np.median(summed_axis0[summed_axis0 != 0])
+
+         target_step_size = avg_height0 * 1.1
+
+         paths = []
+         for i in range(len(baseline_list[ind]) - 1):
+             p1 = baseline_list[ind][i]
+             p2 = baseline_list[ind][i + 1]
+
+             p1_c = complex(*p1)
+             p2_c = complex(*p2)
+
+             paths.append(Line(p1_c, p2_c))
+
+         if len(paths) == 0:
+             continue
+
+         # Add a bit on the end
+         tan = paths[-1].unit_tangent(1.0)
+         p3_c = p2_c + target_step_size * tan
+         paths.append(Line(p2_c, p3_c))
+
+         path = Path(*paths)
+
+         n_w = target_height * path.length() / target_step_size
+         ts = np.arange(0, 1, target_height / float(n_w))
+
+         (rectified_to_warped_x,
+          rectified_to_warped_y) = generate_offset_mapping_1(img, ts, path, target_step_size / 2,
+                                                             -target_step_size / 2,
+                                                             cube_size=target_height)
+
+         rectified_to_warped_x = rectified_to_warped_x[::-1, ::-1]
+         rectified_to_warped_y = rectified_to_warped_y[::-1, ::-1]
+
+         warped = cv2.remap(img, rectified_to_warped_x, rectified_to_warped_y, cv2.INTER_CUBIC, borderValue=(255, 255, 255))
+         warped_list.append(warped)
+
+     return warped_list
+
+
+ def extract_region_mask(img, bounding_poly):
+     pts = np.array(bounding_poly, np.int32)
+
+     # http://stackoverflow.com/a/15343106/3479446
+     mask = np.zeros(img.shape[:2], dtype=np.uint8)
+     roi_corners = np.array([pts], dtype=np.int32)
+
+     ignore_mask_color = (255,)
+     cv2.fillPoly(mask, roi_corners, ignore_mask_color, lineType=cv2.LINE_8)
+     return mask
+
+
+ def get_baseline(poly_points, right_to_left=True):
+     baseline = points.get_baseline_chunks(poly_points)
+     # Right-to-left order
+     #if right_to_left:
+     #    baseline.sort(key=lambda x: x[0], reverse=True)
+     return baseline
+
+
+ # The first argument is ignored if polygon_pts is given
+ def get_line_image(polygon_flat_list, img, polygon_pts=None, target_height=60):
+     if polygon_pts is None:
+         polygon_pts = points.list_to_xy(polygon_flat_list)
+     baseline = get_baseline(polygon_pts)
+     line_img = get_warped_images(img, [polygon_pts], [baseline], target_height=target_height)
+     if len(line_img) > 0:
+         return line_img[0]
+     return None
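
A sketch of warping one annotated line out of a page with get_line_image; the coordinate values and file names are invented for illustration, and the flat x,y list matches the 'coord' format consumed elsewhere in this commit:

import cv2
import warp_routines as warp

img = cv2.imread("page.jpg")  # placeholder page image
coord = [100, 50, 400, 40, 410, 90, 105, 100]  # flat [x1, y1, x2, y2, ...] polygon
line_img = warp.get_line_image(coord, img, target_height=60)
if line_img is not None:
    cv2.imwrite("line.png", line_img)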
coords/__init__.py ADDED
File without changes
coords/points.py ADDED
@@ -0,0 +1,342 @@
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import Rectangle
+ import numpy as np
+ from PIL import Image, ImageDraw
+
+
+ # img_file is the filename with full path
+ # points is a string of coordinate points read from XML
+ def generate_cropped_region(img_file, points):
+     # Process points
+     points_list = points.split(' ')
+     xy_pts = [(int(pt.split(",")[0]),
+                int(pt.split(",")[1])) for pt in points_list]
+     pts = np.array(xy_pts)
+     img_obj = Image.open(img_file)
+     img = np.array(img_obj)
+     # Crop the image
+     [min_x, min_y] = np.min(pts, axis=0)
+     [max_x, max_y] = np.max(pts, axis=0)
+     cropped_img = img[min_y:max_y + 1, min_x:max_x + 1, :]
+     cropped_img_obj = Image.fromarray(cropped_img)
+     return cropped_img_obj, (min_x, min_y), (max_x, max_y)
+
+
+ # pts is a 2D numpy points array
+ # img is also a numpy/cv2 array
+ def generate_cropped_image(img, pts):
+     img_obj = Image.fromarray(img)
+
+     # Create a polygonal mask
+     mask = Image.new('L', (img_obj.width, img_obj.height), color=0)
+     draw_mask = ImageDraw.Draw(mask)
+     draw_mask.polygon(list(pts.flatten()), fill=255)
+     mask = np.array(mask).astype(bool)
+     # Choose the polygonal area from the image
+     output_img = np.zeros_like(img)
+     output_img[mask] = img[mask]
+     # Crop the image
+     [min_x, min_y] = np.min(pts, axis=0).astype(int)
+     [max_x, max_y] = np.max(pts, axis=0).astype(int)
+     cropped_img = output_img[min_y:max_y + 1, min_x:max_x + 1, :]
+
+     return cropped_img
+
+
+ def draw_poly(plt, xy_pts, color='green'):
+     plt.gca().add_patch(Rectangle(xy_pts[0], 10, 10, facecolor='yellow'))
+     for pts1, pts2 in zip(xy_pts, xy_pts[1:]):
+         draw_line(plt, pts1[0], pts1[1], pts2[0], pts2[1], color)
+     draw_line(plt, xy_pts[0][0], xy_pts[0][1],
+               xy_pts[-1][0], xy_pts[-1][1], color)
+
+
+ def draw_baseline(plt, xy_pts, color='red'):
+     plt.gca().add_patch(Rectangle(xy_pts[0], 100, 100, facecolor='blue'))
+     for pts1, pts2 in zip(xy_pts, xy_pts[1:]):
+         draw_line(plt, pts1[0], pts1[1], pts2[0], pts2[1], color=color)
+
+
+ def draw_line(plt_obj, x1, y1, x2, y2, color='g'):
+     plt_obj.plot([x1, x2], [y1, y2], color=color, linewidth=1)
+
+
+ # The argument points is a string.
+ # The function returns a list of (x,y) tuples.
+ def get_xy_pts(points):
+     points_list = points.split(' ')
+     xy_pts = [(int(pt.split(",")[0]),
+                int(pt.split(",")[1])) for pt in points_list]
+     return xy_pts
+
+
+ # bbox is not necessarily a polygon or rectangle, just a list of (x,y) tuples.
+ # If apply_correction is True then all points are restricted to lie within
+ # top_left and bottom_right.
+ def add_offset_to_polygon(bbox, offset, apply_correction=False,
+                           top_left=[], bottom_right=[]):
+     new_bbox = []
+     for i, coord in enumerate(bbox):
+         new_bbox.append((coord[0] + offset[0], coord[1] + offset[1]))
+     if apply_correction:
+         top_left = np.array(top_left)
+         bottom_right = np.array(bottom_right)
+         pts = np.array(new_bbox)
+         for j in [0, 1]:
+             ind = np.where(pts[:, j] > bottom_right[j])
+             pts[ind, j] = bottom_right[j]
+         for j in [0, 1]:
+             ind = np.where(pts[:, j] < top_left[j])
+             pts[ind, j] = top_left[j]
+         new_bbox = list(map(tuple, pts))
+
+     return new_bbox
+
+
+ def add_offset_to_polygon_list(polygon_list, offset):
+     new_polygon_list = []
+     for poly in polygon_list:
+         new_poly = add_offset_to_polygon(poly, offset)
+         new_polygon_list.append(new_poly)
+     return new_polygon_list
+
+
+ def combine_poly(poly1, poly2):
+     main_poly = [poly1[0], poly1[1]]
+     if poly1[1][0] != poly2[0][0] or poly1[1][1] != poly2[0][1]:
+         main_poly.append(poly2[0])
+     main_poly.extend(poly2[1:3])
+     if poly1[2][0] != poly2[3][0] or poly1[2][1] != poly2[3][1]:
+         main_poly.append(poly2[3])
+     main_poly.extend(poly1[2:])
+     return main_poly
+
+
+ # Get left, upper, right, lower (min_x, min_y, max_x, max_y)
+ def get_max_min_polygon(polygon):
+     x_list = [x[0] for x in polygon]
+     y_list = [y[1] for y in polygon]
+     return min(x_list), min(y_list), max(x_list), max(y_list)
+
+
+ def add_offset_to_baseline(baseline, offset):
+     new_baseline = []
+     for pts in baseline:
+         new_baseline.append((pts[0] + offset[0], pts[1] + offset[1]))
+     return new_baseline
+
+
+ def add_offset_to_baseline_list(baseline_list, offset):
+     new_list = []
+     for b in baseline_list:
+         new_b = add_offset_to_baseline(b, offset)
+         new_list.append(new_b)
+     return new_list
+
+
+ def combine_baseline(base1, base2):
+     #combined = [base1[0], base1[1], base2[0], base2[1]]
+     combined = [base2[1], base2[0], base1[1], base1[0]]
+     return combined
+
+
+ def get_x_y(polygon):
+     x_list = [x[0] for x in polygon]
+     y_list = [y[1] for y in polygon]
+     return x_list, y_list
+
+
+ # num_pts is the number of points on the baseline to get
+ def get_baseline_regression(poly_pts, num_pts=10, deg=1):
+     if len(poly_pts) <= 4:
+         deg = 1
+     x, y = get_x_y(poly_pts)
+     model = np.polyfit(x, y, deg)
+     p = np.poly1d(model)
+     # Get the x values against which we want y
+     x1, y1, x2, y2 = get_max_min_polygon(poly_pts)
+     num_pts = min(num_pts, x2 - x1 + 1)
+     num_pts = int(num_pts)
+     x = np.linspace(x1, x2, num_pts, endpoint=True, dtype=int)
+     y = p(x)
+     baseline = [(a, b) for a, b in zip(x, y)]
+     return baseline
+
+
+ # Given a flat coordinates list, return a list of (x,y) tuples
+ def list_to_xy(coord_list):
+     xy_list = []
+     for ind in range(0, len(coord_list), 2):
+         xy_list.append((coord_list[ind], coord_list[ind + 1]))
+     return xy_list
+
+
+ # Given an (x,y) list of tuples, return a flat list
+ def xy_to_list(tuples_list):
+     flat_list = [x for pair in tuples_list for x in pair]
+     return flat_list
+
+
+ # x is a list of points
+ # y is a corresponding list of points
+ def get_baseline_from_xy(x, y, num_pts=10, deg=1):
+     if len(x) <= 4:
+         deg = 1
+
+     model = np.polyfit(x, y, deg)
+     p = np.poly1d(model)
+     # Get the x values against which we want y
+     x1, x2 = min(x), max(x)
+     num_pts = int(min(num_pts, x2 - x1 + 1))
+     x = np.linspace(x1, x2, num_pts, endpoint=True, dtype=int)
+     y = p(x)
+     baseline = [(a, b) for a, b in zip(x, y)]
+     return baseline
+
+
+ # Here img is the 2D numpy array
+ # xy_pts is a list of (x,y) tuples
+ def generate_cropped_region_from_polypts(img, xy_pts):
+     pts = np.array(xy_pts)
+
+     # Crop the image
+     [min_x, min_y] = np.ceil(np.min(pts, axis=0)).astype(int)
+     [max_x, max_y] = np.floor(np.max(pts, axis=0)).astype(int)
+     cropped_img = img[min_y:max_y + 1, min_x:max_x + 1, :]
+     cropped_img_obj = Image.fromarray(cropped_img)
+     return cropped_img_obj, (min_x, min_y), (max_x, max_y)
+
+
+ # Will generate a line image by using xy_pts as a mask
+ def generate_line_image(img, xy_pts):
+     pts = np.array(xy_pts)
+     [min_x, min_y] = np.ceil(np.min(pts, axis=0)).astype(int)
+     [max_x, max_y] = np.floor(np.max(pts, axis=0)).astype(int)
+     (width, ht) = (max_x - min_x + 1, max_y - min_y + 1)
+     xy_pts = add_offset_to_polygon(xy_pts, (-min_x, -min_y))
+     img = img[min_y:max_y + 1, min_x:max_x + 1, :]
+
+     # Create a polygonal mask
+     mask = Image.new('L', (img.shape[1], img.shape[0]), color=0)
+     draw_mask = ImageDraw.Draw(mask)
+     draw_mask.polygon(xy_pts, fill=255)
+     mask = np.array(mask).astype(bool)
+     # Choose the polygonal area from the image; white elsewhere
+     output_img = np.zeros_like(img) + 255
+     output_img[mask] = img[mask]
+     output_img = Image.fromarray(output_img)
+     return output_img, xy_pts
+
+
+ # Restrict coordinates to lie between 0 and max (inclusive)
+ def restrict_pts(pts, max_p):
+     pts = [(max(0, x), max(0, y)) for (x, y) in pts]
+     pts = [(min(max_p[0], x), min(max_p[1], y)) for (x, y) in pts]
+     return pts
+
+
+ # Assuming poly_pts is a list of (x,y) tuples.
+ # This will add more points by interpolating between consecutive points.
+ def expand_poly(poly_pts, min_x_increment=10):
+     poly_pts = np.array(poly_pts).astype(int)
+     new_poly = []
+
+     for ind, curr in enumerate(poly_pts):
+         nxt = poly_pts[(ind + 1) % len(poly_pts)]
+         if np.abs(nxt[0] - curr[0]) < min_x_increment:
+             new_poly.append(curr)
+             new_poly.append(nxt)
+             continue
+         x1, x2 = curr[0], nxt[0]
+         y1, y2 = curr[1], nxt[1]
+         new_poly.append(curr)
+         increment = -1 * min_x_increment if nxt[0] < curr[0] else min_x_increment
+         for x in range(curr[0] + 1, nxt[0], increment):
+             slope = float(y2 - y1) / (x2 - x1)
+             y = slope * (x - x2) + y2
+             new_poly.append((x, y))
+         new_poly.append(nxt)
+     return new_poly
+
+
+ # chunks_len is ignored if chunk_len_auto is True
+ def get_baseline_chunks(poly_pts, chunks_len=300, chunk_len_auto=True):
+     baseline = []
+     poly_pts = expand_poly(poly_pts)
+
+     p = np.array(poly_pts)
+
+     max_x, max_y = np.max(p, 0)
+     min_x, min_y = np.min(p, 0)
+
+     # Decide chunks_len
+     if chunk_len_auto:
+         if len(poly_pts) >= 250:
+             total_chunks = 5
+         else:
+             total_chunks = int(np.ceil(len(poly_pts) / 50))
+         chunks_len = int((max_x - min_x) / total_chunks)
+     else:
+         total_chunks = int((max_x - min_x) / chunks_len)
+
+     # Fit a regression baseline per chunk of the polygon
+     for i in range(1, total_chunks + 1):
+         p1 = [pt for pt in p if (pt[0] - min_x) >= (i - 1) * chunks_len and (pt[0] - min_x) < i * chunks_len]
+         if i == total_chunks:
+             p1 = [pt for pt in p if (pt[0] - min_x) >= (i - 1) * chunks_len]
+         b = get_baseline_regression(p1, num_pts=12)
+
+         # Points are in ascending order (increasing x, left to right)
+         if len(baseline) != 0:
+             # This will smooth out the line:
+             # drop the last 4 points and connect with the next 4 points
+             baseline = baseline[:-4]
+             baseline.extend(b[4:])
+             if i == total_chunks:
+                 baseline.extend(b[-1:])
+         else:
+             baseline = b
+
+     # Make sure the baseline does not repeat
+
+     prev_pt = baseline[0]
302
+ baseline_clean = [prev_pt]
303
+ for b in baseline[1:]:
304
+ if b != prev_pt:
305
+ baseline_clean.append(b)
306
+ prev_pt = b
307
+ return baseline_clean
308
+
309
+
310
+ # Making sure a value is not outside a boundary or has negative value
311
+ def correct_pt(value, max_value):
312
+ if value < 0:
313
+ return 0
314
+ if value > max_value:
315
+ return max_value
316
+ return value
317
+
318
+ # Assume poly is list of (x, y) tuples or [x, y] list
319
+ # Will retrieve a vertically oriented baseline
320
+ # Will return top to bottom and bottom to top if reversed is true
321
+ def get_vertical_baseline(poly, reversed=False):
322
+ flipped_poly = [(y, x) for (x, y) in poly]
323
+ baseline = get_baseline_chunks(flipped_poly)
324
+ # Flip back
325
+ baseline = [(y, x) for (x, y) in baseline]
326
+ if reversed:
327
+ baseline.sort(key=lambda x: x[1], reverse=True)
328
+ return baseline
329
+
330
+ # Check the polygon is valid
331
+ # If x coord or y coord don't change, its not valid
332
+ def valid_poly(poly_pts):
333
+ if len(poly_pts) <= 2:
334
+ return False
335
+ x = [pt[0] for pt in poly_pts]
336
+ y = [pt[1] for pt in poly_pts]
337
+
338
+ if np.max(x) - np.min(x) <= 1e-2:
339
+ return False
340
+ if np.max(y) - np.min(y) <= 1e-2:
341
+ return False
342
+ return True
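A minimal usage sketch for the baseline routines above; the polygon values here are made up for illustration:

import points

# A rough quadrilateral around a text line, as (x, y) tuples
poly = [(10, 50), (400, 40), (400, 90), (10, 100)]
if points.valid_poly(poly):
    baseline = points.get_baseline_chunks(poly)  # piecewise-fitted baseline
    flat = points.xy_to_list(baseline)           # flatten to [x0, y0, x1, y1, ...]
    print(len(baseline), flat[:4])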
coords/poly_routines.py ADDED
@@ -0,0 +1,195 @@
import numpy as np
import cv2
import json
import sys

import points
from PIL import Image, ImageDraw
import os
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle


def correct_pt(value, max_value):
    boundary = False
    if value < 0:
        value = 0
        boundary = True
    if value >= max_value:
        value = max_value - 1
        boundary = True
    return [value, boundary]


# Each polygon is a list of (x, y) tuples
def get_polygon_list_tuples(out):
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    polygon_list = []
    prev = [-1, -1]
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        begin_ind = out['beginning'][line_ind]
        end_ind = out['ending'][line_ind]
        begin_ind = int(np.floor(begin_ind))
        end_ind = int(np.ceil(end_ind))
        end_ind = min(end_ind, len(out['lf']) - 1)
        for pt_ind in range(begin_ind, end_ind + 1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][0])
            pt_y = float(out['lf'][pt_ind][line_ind][1][0])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]
        for pt_ind in range(end_ind, begin_ind - 1, -1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][1])
            pt_y = float(out['lf'][pt_ind][line_ind][1][1])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        polygon_list.append(polygon)
        if len(polygon) < 3:
            print('WARNING: DEGENERATE POLYGON AT INDEX', len(polygon_list))
    return polygon_list


# Each polygon is a list of (x, y) tuples
def get_polygon_list_without_trim(out):
    img = cv2.imread(out["image_path"])
    img_height, img_width = img.shape[:2]
    polygon_list = []
    for line_ind in range(len(out['lf'][0])):
        polygon = []
        begin_ind = 0
        end_ind = len(out['lf']) - 1
        prev = [-1, -1]

        for pt_ind in range(begin_ind, end_ind + 1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][0])
            pt_y = float(out['lf'][pt_ind][line_ind][1][0])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]
        for pt_ind in range(end_ind, begin_ind - 1, -1):
            pt_x = float(out['lf'][pt_ind][line_ind][0][1])
            pt_y = float(out['lf'][pt_ind][line_ind][1][1])
            pt_x, boundary_x = correct_pt(pt_x, img_width)
            pt_y, boundary_y = correct_pt(pt_y, img_height)
            if prev != [pt_x, pt_y]:
                polygon.append((pt_x, pt_y))
            prev = [pt_x, pt_y]

        if len(polygon) >= 3:
            polygon_list.append(polygon)
    return polygon_list


# Each polygon passed as input is a list of (x, y) tuples
def percent_intersection(size, poly1, poly2):
    im1 = Image.new(mode="1", size=size)
    draw1 = ImageDraw.Draw(im1)
    draw1.polygon(poly1, fill=1)
    im2 = Image.new(mode="1", size=size)
    draw2 = ImageDraw.Draw(im2)
    draw2.polygon(poly2, fill=1)
    mask1 = np.asarray(im1, dtype=bool)
    mask2 = np.asarray(im2, dtype=bool)
    intersection_mask = mask1 & mask2
    intersection_area = intersection_mask.sum()
    percent1 = intersection_area / mask1.sum()
    percent2 = intersection_area / mask2.sum()
    return intersection_area, percent1, percent2


def get_poly_no_overlap(img_name, poly_list, threshold=0.6):
    img = Image.open(img_name)
    size = img.size
    polygons = poly_list
    del_list = []
    current = 0
    next_ind = current + 1
    last_deleted = -1
    while next_ind < len(polygons):
        # Check that these are not degenerate polygons
        if len(polygons[current]) < 3:
            del_list.append(current)
            current, next_ind = (current + 1, next_ind + 1)
            continue
        if len(polygons[next_ind]) < 3:
            del_list.append(next_ind)
            next_ind += 1
            continue
        # End check
        overlap_area, percent1, percent2 = percent_intersection(size,
                                                                polygons[current],
                                                                polygons[next_ind])

        if percent1 > threshold or percent2 > threshold:
            to_del = current if percent1 > percent2 else next_ind
            current, next_ind = (current, next_ind + 1) if percent1 < percent2 \
                else (next_ind, next_ind + 1)
            del_list.append(to_del)
            last_deleted = to_del
        else:  # when no overlap is found
            current, next_ind = (current + 1, next_ind + 1)
            if current <= last_deleted:
                current = last_deleted + 1
                next_ind = current + 1
    all_ind = set(range(len(poly_list)))
    good_ind = sorted(all_ind.difference(set(del_list)))
    poly_non_overlapping = [poly_list[i] for i in good_ind]
    return del_list, poly_non_overlapping


def dump_polygons_json(out, polygons=None, filename=None):
    if filename is None:
        filename = out["image_path"][:-3] + "json"
    if polygons is None:
        # Fixed: the original called an undefined get_polygon_list()
        polygons = get_polygon_list_tuples(out)
    lf_dict = {}
    for ind, poly in enumerate(polygons):
        lf_dict['line_' + str(ind + 1)] = points.xy_to_list(poly)

    with open(filename, 'w') as fout:
        json_dumps_str = json.dumps(lf_dict, indent=2)
        print(json_dumps_str, file=fout)


# Won't flip the polygons... only the image
def draw_image_with_poly(directory, image, poly, convert=True, flip=False):
    img = cv2.imread(os.path.join(directory, image))
    if flip:
        img = cv2.flip(img, 1)
    plt.imshow(img)
    colors = ['red', 'green', 'blue']

    for ind, p in enumerate(poly):
        if convert:
            p = points.list_to_xy(p)
        points.draw_poly(plt, p, colors[ind % 3])
        plt.text(p[-1][0], p[-1][1], str(ind))


def flip_polygon(img_file, poly_list):
    img = cv2.imread(img_file)
    h, w = img.shape[:2]
    flipped_poly_list = []
    for p in poly_list:
        flipped = [(w - x, y) for (x, y) in p]
        flipped_poly_list.append(flipped)
    return flipped_poly_list
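A short sketch of the overlap filter above; 'page.png' and the polygon values are placeholders:

import poly_routines

polygons = [
    [(10, 10), (200, 10), (200, 60), (10, 60)],
    [(12, 12), (198, 14), (198, 58), (12, 58)],  # heavily overlaps the first
]
deleted, kept = poly_routines.get_poly_no_overlap('page.png', polygons, threshold=0.6)
print('deleted indices:', deleted, 'polygons kept:', len(kept))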
coords/text_cleaning_routines.py ADDED
@@ -0,0 +1,82 @@
import json
from bidi.algorithm import get_display
import re


def correct_brackets(text):
    text = switch_chars(text, '{', '}')
    text = switch_chars(text, '(', ')')
    text = switch_chars(text, '[', ']')
    text = switch_chars(text, '«', '»')
    return text


def switch_chars(text, x, y):
    t = list(text)
    ind_x = [i for i, j in enumerate(t) if j == x]
    ind_y = [i for i, j in enumerate(t) if j == y]
    for i in ind_x:
        t[i] = y
    for i in ind_y:
        t[i] = x
    return ''.join(t)


def clean_text(input_text):
    cleaned_text = input_text.replace('\u0009', ' ')
    cleaned_text = cleaned_text.replace('\u000A', ' ')
    cleaned_text = cleaned_text.replace('\u00D7', 'x')
    cleaned_text = cleaned_text.replace('\u066A', '%')
    cleaned_text = cleaned_text.replace('\u06f3', '\u0663')
    cleaned_text = cleaned_text.replace('\u06f7', '\u0667')
    cleaned_text = cleaned_text.replace('\u06f9', '\u0669')

    cleaned_text = cleaned_text.replace('\u2018', "'")
    cleaned_text = cleaned_text.replace('\u2019', "'")
    cleaned_text = cleaned_text.replace('\u201C', '"')
    cleaned_text = cleaned_text.replace('\u201D', '"')
    cleaned_text = cleaned_text.replace('…', '...')
    cleaned_text = cleaned_text.replace('\u2033', "\u064b")
    cleaned_text = cleaned_text.replace('\u2044', '/')
    cleaned_text = cleaned_text.replace('\u2e17', '\u201e')
    pattern = r'[\u2013\u2014]'
    cleaned_text = re.sub(pattern, '-', cleaned_text)
    pattern = r'[●•\xb7]'
    cleaned_text = re.sub(pattern, '.', cleaned_text)
    return cleaned_text


def get_char_sets():
    english_lower = range(ord('a'), ord('z') + 1)
    english_upper = range(ord('A'), ord('Z') + 1)

    english_numbers = range(ord('0'), ord('9') + 1)

    english_ord = set(english_lower).union(english_upper)
    english_numbers = {chr(c) for c in set(english_numbers)}
    english_alphabet = {chr(c) for c in english_ord}

    # This range includes numerals/digits also
    arabic_unicodes = range(ord("\u0600"), ord("\u06ff") + 1)
    arabic_ord = set(arabic_unicodes)
    arabic_chars = {chr(c) for c in arabic_ord}
    arabic_numbers_ord = range(ord("\u0660"), ord("\u0669") + 1)
    arabic_digits = {chr(c) for c in arabic_numbers_ord}
    return {'english_alphabet': english_alphabet, 'arabic_unicodes': arabic_chars,
            'latin_digits': english_numbers, 'arabic_digits': arabic_digits}


def get_clean_visual_order(text):
    charset_dict = get_char_sets()
    text_set = set(text)
    has_english_alphabet = len(text_set.intersection(charset_dict['english_alphabet'])) > 0
    has_latin_digits = len(text_set.intersection(charset_dict['latin_digits'])) > 0
    has_arabic_digits = len(text_set.intersection(charset_dict['arabic_digits'])) > 0

    if has_arabic_digits or has_english_alphabet or has_latin_digits:
        text_visual_order = get_display(text, base_dir='R')[::-1]
        text_visual_order = correct_brackets(text_visual_order)
    else:
        text_visual_order = text
    clean_visual_order = clean_text(text_visual_order)
    return clean_visual_order
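A small sketch of the cleaning pipeline above on a line that mixes Latin characters with typographic punctuation:

import text_cleaning_routines as tcr

line = 'ABC \u2013 \u201Ctest\u201D \u2026'
print(tcr.get_clean_visual_order(line))  # bidi-reordered, brackets switched, then cleaned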
coords/text_gt.py ADDED
@@ -0,0 +1,101 @@
# Extract text from a json file
# As all lines need sorting etc., this is kept in the coords folder

import sys
import points
import json
import pandas as pd
import os
import numpy as np


def sort_lines(lines_list):

    line_starts = [line['baseline'][0] for line in lines_list]
    sorted_starts = sorted(enumerate(line_starts), key=lambda x: (x[1][1], -x[1][0]))
    sorted_start_ind = [x[0] for x in sorted_starts]
    sorted_lines = [lines_list[i] for i in sorted_start_ind]

    return sorted_lines


def is_valid_key(key, json_obj):
    if not key.lower().startswith('line_'):
        return False
    if "deleted" in json_obj[key].keys() and json_obj[key]["deleted"] != "0":
        return False
    # No text field
    if 'text' not in json_obj[key]:
        return False
    # Text empty
    json_obj[key]['text'] = json_obj[key]['text'].replace('\t', ' ')
    if json_obj[key]['text'].strip() == "":
        return False
    return True


def get_text(json_file, return_list=False):
    with open(json_file) as fin:
        json_obj = json.load(fin)
    to_remove_ind = []
    # Get keys in json_obj
    keys = list(json_obj.keys())
    # Get the list of line objects
    lines = [json_obj[k] for k in keys if is_valid_key(k, json_obj)]
    # Get the baseline of each line
    for ind, line in enumerate(lines):
        poly_pts = line["coord"]
        poly_pts = points.list_to_xy(poly_pts)
        if len(poly_pts) <= 2:
            print(json_file, len(poly_pts))
            to_remove_ind.append(ind)
        if not points.valid_poly(poly_pts):
            to_remove_ind.append(ind)
            continue
        try:
            baseline = points.get_baseline_chunks(poly_pts)
            baseline.sort(key=lambda x: x[0], reverse=True)
            line['baseline'] = baseline
        except Exception as e:
            to_remove_ind.append(ind)

    # Remove the lines causing exceptions
    cleaned_lines = [lines[ind] for ind in range(len(lines)) if ind not in to_remove_ind]
    # Sort the lines
    sorted_lines = sort_lines(cleaned_lines)

    text = []
    for l in sorted_lines:
        text.append(l["text"])
    if return_list:
        return text
    return '\n'.join(text)


def get_json_file(img_fullname):
    dir_name, img_name = os.path.split(img_fullname)
    json_files = []
    base_file = img_name[:-4]
    files = os.listdir(dir_name)
    for f in files:
        prefix = base_file + '_annotate_'
        if f.startswith(prefix):
            # Keep only files without a timestamp in the name, i.e., the
            # part after the prefix contains a single '.'
            partial_string = f[len(prefix):]
            ind1 = partial_string.rfind('.')
            ind2 = partial_string.find('.')

            if ind1 == ind2:
                json_files.append(f)

    if len(json_files) > 1:
        print('More than one json found... returning the 0th one', json_files)
    if len(json_files) == 0:
        print('No json found')
        return None

    return os.path.join(dir_name, json_files[0])
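A usage sketch for the two helpers above; the file names are placeholders and follow the page_01_annotate_<name>.json naming scheme the code expects:

import text_gt

json_file = text_gt.get_json_file('scans/page_01.png')
if json_file is not None:
    lines = text_gt.get_text(json_file, return_list=True)
    print(len(lines), 'ground-truth lines')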
model/trial_26_A/muharaf_charset.json ADDED
@@ -0,0 +1,320 @@
{
    "char_to_idx": {
        " ": 1,
        "!": 2,
        "\"": 3,
        "#": 4,
        "$": 5,
        "%": 6,
        "&": 7,
        "'": 8,
        "(": 9,
        ")": 10,
        "*": 11,
        "+": 12,
        ",": 13,
        "-": 14,
        ".": 15,
        "/": 16,
        "0": 17,
        "1": 18,
        "2": 19,
        "3": 20,
        "4": 21,
        "5": 22,
        "6": 23,
        "7": 24,
        "8": 25,
        "9": 26,
        ":": 27,
        "=": 28,
        "A": 29,
        "B": 30,
        "C": 31,
        "D": 32,
        "E": 33,
        "F": 34,
        "G": 35,
        "H": 36,
        "I": 37,
        "J": 38,
        "K": 39,
        "L": 40,
        "M": 41,
        "N": 42,
        "O": 43,
        "P": 44,
        "Q": 45,
        "R": 46,
        "S": 47,
        "T": 48,
        "U": 49,
        "V": 50,
        "W": 51,
        "X": 52,
        "Y": 53,
        "Z": 54,
        "[": 55,
        "\\": 56,
        "]": 57,
        "_": 58,
        "a": 59,
        "b": 60,
        "c": 61,
        "d": 62,
        "e": 63,
        "f": 64,
        "g": 65,
        "h": 66,
        "i": 67,
        "j": 68,
        "k": 69,
        "l": 70,
        "m": 71,
        "n": 72,
        "o": 73,
        "p": 74,
        "q": 75,
        "r": 76,
        "s": 77,
        "t": 78,
        "u": 79,
        "v": 80,
        "x": 81,
        "y": 82,
        "z": 83,
        "|": 84,
        "\u00ba": 85,
        "\u00c3": 86,
        "\u00c8": 87,
        "\u00c9": 88,
        "\u00ca": 89,
        "\u00e0": 90,
        "\u00e7": 91,
        "\u00e8": 92,
        "\u00e9": 93,
        "\u00ea": 94,
        "\u060c": 95,
        "\u061b": 96,
        "\u061f": 97,
        "\u0621": 98,
        "\u0622": 99,
        "\u0623": 100,
        "\u0624": 101,
        "\u0625": 102,
        "\u0626": 103,
        "\u0627": 104,
        "\u0628": 105,
        "\u0629": 106,
        "\u062a": 107,
        "\u062b": 108,
        "\u062c": 109,
        "\u062d": 110,
        "\u062e": 111,
        "\u062f": 112,
        "\u0630": 113,
        "\u0631": 114,
        "\u0632": 115,
        "\u0633": 116,
        "\u0634": 117,
        "\u0635": 118,
        "\u0636": 119,
        "\u0637": 120,
        "\u0638": 121,
        "\u0639": 122,
        "\u063a": 123,
        "\u0640": 124,
        "\u0641": 125,
        "\u0642": 126,
        "\u0643": 127,
        "\u0644": 128,
        "\u0645": 129,
        "\u0646": 130,
        "\u0647": 131,
        "\u0648": 132,
        "\u0649": 133,
        "\u064a": 134,
        "\u064b": 135,
        "\u064c": 136,
        "\u064d": 137,
        "\u064e": 138,
        "\u064f": 139,
        "\u0650": 140,
        "\u0651": 141,
        "\u0652": 142,
        "\u0660": 143,
        "\u0661": 144,
        "\u0662": 145,
        "\u0663": 146,
        "\u0664": 147,
        "\u0665": 148,
        "\u0666": 149,
        "\u0667": 150,
        "\u0668": 151,
        "\u0669": 152,
        "\u06a4": 153,
        "\u06a8": 154,
        "\u201e": 155,
        "\ufb6c": 156,
        "\ufc63": 157
    },
    "idx_to_char": {
        "1": " ",
        "2": "!",
        "3": "\"",
        "4": "#",
        "5": "$",
        "6": "%",
        "7": "&",
        "8": "'",
        "9": "(",
        "10": ")",
        "11": "*",
        "12": "+",
        "13": ",",
        "14": "-",
        "15": ".",
        "16": "/",
        "17": "0",
        "18": "1",
        "19": "2",
        "20": "3",
        "21": "4",
        "22": "5",
        "23": "6",
        "24": "7",
        "25": "8",
        "26": "9",
        "27": ":",
        "28": "=",
        "29": "A",
        "30": "B",
        "31": "C",
        "32": "D",
        "33": "E",
        "34": "F",
        "35": "G",
        "36": "H",
        "37": "I",
        "38": "J",
        "39": "K",
        "40": "L",
        "41": "M",
        "42": "N",
        "43": "O",
        "44": "P",
        "45": "Q",
        "46": "R",
        "47": "S",
        "48": "T",
        "49": "U",
        "50": "V",
        "51": "W",
        "52": "X",
        "53": "Y",
        "54": "Z",
        "55": "[",
        "56": "\\",
        "57": "]",
        "58": "_",
        "59": "a",
        "60": "b",
        "61": "c",
        "62": "d",
        "63": "e",
        "64": "f",
        "65": "g",
        "66": "h",
        "67": "i",
        "68": "j",
        "69": "k",
        "70": "l",
        "71": "m",
        "72": "n",
        "73": "o",
        "74": "p",
        "75": "q",
        "76": "r",
        "77": "s",
        "78": "t",
        "79": "u",
        "80": "v",
        "81": "x",
        "82": "y",
        "83": "z",
        "84": "|",
        "85": "\u00ba",
        "86": "\u00c3",
        "87": "\u00c8",
        "88": "\u00c9",
        "89": "\u00ca",
        "90": "\u00e0",
        "91": "\u00e7",
        "92": "\u00e8",
        "93": "\u00e9",
        "94": "\u00ea",
        "95": "\u060c",
        "96": "\u061b",
        "97": "\u061f",
        "98": "\u0621",
        "99": "\u0622",
        "100": "\u0623",
        "101": "\u0624",
        "102": "\u0625",
        "103": "\u0626",
        "104": "\u0627",
        "105": "\u0628",
        "106": "\u0629",
        "107": "\u062a",
        "108": "\u062b",
        "109": "\u062c",
        "110": "\u062d",
        "111": "\u062e",
        "112": "\u062f",
        "113": "\u0630",
        "114": "\u0631",
        "115": "\u0632",
        "116": "\u0633",
        "117": "\u0634",
        "118": "\u0635",
        "119": "\u0636",
        "120": "\u0637",
        "121": "\u0638",
        "122": "\u0639",
        "123": "\u063a",
        "124": "\u0640",
        "125": "\u0641",
        "126": "\u0642",
        "127": "\u0643",
        "128": "\u0644",
        "129": "\u0645",
        "130": "\u0646",
        "131": "\u0647",
        "132": "\u0648",
        "133": "\u0649",
        "134": "\u064a",
        "135": "\u064b",
        "136": "\u064c",
        "137": "\u064d",
        "138": "\u064e",
        "139": "\u064f",
        "140": "\u0650",
        "141": "\u0651",
        "142": "\u0652",
        "143": "\u0660",
        "144": "\u0661",
        "145": "\u0662",
        "146": "\u0663",
        "147": "\u0664",
        "148": "\u0665",
        "149": "\u0666",
        "150": "\u0667",
        "151": "\u0668",
        "152": "\u0669",
        "153": "\u06a4",
        "154": "\u06a8",
        "155": "\u201e",
        "156": "\ufb6c",
        "157": "\ufc63"
    }
}
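A sketch of loading this character set; note the indices start at 1, leaving index 0 free (typically the CTC blank):

import json

with open('model/trial_26_A/muharaf_charset.json') as f:
    charset = json.load(f)

char_to_idx = charset['char_to_idx']
# JSON keys are strings, so convert them back to ints for lookups
idx_to_char = {int(k): v for k, v in charset['idx_to_char'].items()}
print(len(char_to_idx), idx_to_char[104])  # 157 characters; 104 -> '\u0627'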
model/trial_26_A/set0/config_2600.yaml ADDED
@@ -0,0 +1,25 @@
snapshot_path: model/trial_26_A/set0/pretrain/


network:
  hw:
    char_set_path: model/trial_26_A/muharaf_charset.json
    cnn_out_size: 1024
    input_height: 60
    num_of_channels: 3
    num_of_outputs: 158
    use_instance_norm: true
  lf:
    look_ahead_matrix: null
    step_bias: null

  sol:
    base0: 16
    base1: 16
post_processing:
  lf_nms_range:
  - 0
  - 6
  lf_nms_threshold: 0.5
  sol_threshold: 0.1
model/trial_26_A/set0/pretrain/hw.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5e1f8a27a6fc5e64c1c93b72b18200db7b6c240053d3e016ad138b4cd3521b08
size 73251954
model/trial_26_A/set0/pretrain/lf.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1b7082953c1269b67ba0801311604884255a4fe4bdef821dcf1ce169ab582d6f
size 22228762
model/trial_26_A/set0/pretrain/sol.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e94e56959c5175067734e00d7e2e02ff05018f37eb91d726f2319504df60c832
size 36979951
py3/e2e/__init__.py ADDED
File without changes
py3/e2e/alignment_dataset.py ADDED
@@ -0,0 +1,69 @@
from torch.utils.data import Dataset
import json
import os
import cv2
import numpy as np
import torch
import random
from utils import safe_load


def collate(batch):
    return batch


class AlignmentDataset(Dataset):

    def __init__(self, set_list, data_range=None, ignore_json=False, resize_width=512):

        self.ignore_json = ignore_json

        self.resize_width = resize_width

        self.ids = set_list
        self.ids.sort()

        if data_range is not None:
            self.ids = random.sample(self.ids, data_range)

        print("Alignment Ids Count:", len(self.ids))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):

        gt_json_path, img_path = self.ids[idx]

        gt_json = []
        if not self.ignore_json:
            gt_json = safe_load.json_state(gt_json_path)
            if gt_json is None:
                return None

        org_img = cv2.imread(img_path)

        full_img = org_img.astype(np.float32)
        full_img = full_img.transpose([2, 1, 0])[None, ...]
        full_img = torch.from_numpy(full_img)
        full_img = full_img / 128 - 1

        target_dim1 = self.resize_width
        s = target_dim1 / float(org_img.shape[1])
        target_dim0 = int(org_img.shape[0] / float(org_img.shape[1]) * target_dim1)

        img = cv2.resize(org_img, (target_dim1, target_dim0), interpolation=cv2.INTER_CUBIC)
        img = img.astype(np.float32)
        img = img.transpose([2, 1, 0])[None, ...]
        img = torch.from_numpy(img)
        img = img / 128 - 1

        image_key = gt_json_path[:-len('.json')]

        return {
            "resized_img": img,
            "full_img": full_img,
            "resize_scale": 1.0 / s,
            "gt_lines": [x['gt'] for x in gt_json],
            "img_key": image_key,
            "json_path": gt_json_path,
            "gt_json": gt_json
        }
py3/e2e/e2e_model.py ADDED
@@ -0,0 +1,207 @@
import torch
import torch.nn as nn
from torch.autograd import Variable

import cv2
import numpy as np
from utils import string_utils, error_rates
from utils import transformation_utils
from . import handwriting_alignment_loss

from . import e2e_postprocessing

import copy
from scipy.optimize import linear_sum_assignment
import math
# from pynvml import *


# max_lines_per_image is the max number of lines in a batch for the HW network to process
class E2EModel(nn.Module):
    def __init__(self, sol, lf, hw, dtype=torch.cuda.FloatTensor, max_lines_per_image=8, device="cuda"):
        super(E2EModel, self).__init__()

        self.dtype = dtype

        self.sol = sol
        self.lf = lf
        self.hw = hw
        self.line = None
        self.max_lines_per_image = max_lines_per_image

        self.device = device

    def train(self):
        self.sol.train()
        self.lf.train()
        self.hw.train()

    def eval(self):
        self.sol.eval()
        self.lf.eval()
        self.hw.eval()

    def forward(self, x, use_full_img=True, accpet_threshold=0.1, volatile=True, gt_lines=None,
                idx_to_char=None, HW_cuda=0, device="cuda"):

        if device != self.device:
            # Fixed: the original crashed here through an undefined name
            raise ValueError('Wrong device is set: param {} vs self {}'.format(device, self.device))

        sol_img = Variable(x['resized_img'].type(self.dtype), requires_grad=False)

        if use_full_img:
            img = Variable(x['full_img'].type(self.dtype), requires_grad=False)
            scale = x['resize_scale']
            results_scale = 1.0
        else:
            img = sol_img
            scale = 1.0
            results_scale = x['resize_scale']

        original_starts = self.sol(sol_img)

        start = original_starts

        # Take at least one point
        sorted_start, sorted_indices = torch.sort(start[..., 0:1], dim=1, descending=True)
        min_threshold = sorted_start[0, 1, 0].data
        accpet_threshold = min(accpet_threshold, min_threshold)
        # There should not be more than 56 points, to avoid running out of memory
        if sorted_start.size()[1] > 56:
            accpet_threshold = max(accpet_threshold, sorted_start[0, 55, 0].data)
        select = original_starts[..., 0:1] >= accpet_threshold

        select_idx = np.where(select.data.cpu().numpy())[1]

        select = select.expand(select.size(0), select.size(1), start.size(2))
        select = select.detach()
        start = start[select].view(start.size(0), -1, start.size(2))

        perform_forward = len(start.size()) == 3

        if not perform_forward:
            return None

        forward_img = img

        start = start.transpose(0, 1)

        positions = torch.cat([
            start[..., 1:3] * scale,
            start[..., 3:4],
            start[..., 4:5] * scale,
            start[..., 0:1]
        ], 2)

        hw_out = []
        p_interval = positions.size(0)
        lf_xy_positions = None
        line_imgs = []
        for p in range(0, positions.size(0), p_interval):
            sub_positions = positions[p:p + p_interval, 0, :]
            sub_select_idx = select_idx[p:p + p_interval]

            batch_size = sub_positions.size(0)
            sub_positions = [sub_positions]

            expand_img = forward_img.expand(sub_positions[0].size(0), img.size(1), img.size(2), img.size(3))

            step_size = 8
            extra_bw = 1
            forward_steps = 30

            # Follow the line a few steps forward, then backward, then
            # forward again until the end of the line
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, sub_positions, steps=step_size)
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size]], steps=step_size + extra_bw, negate_lw=True)
            grid_line, _, out_positions, xy_positions = self.lf(expand_img, [out_positions[step_size + extra_bw]], steps=forward_steps, allow_end_early=True)

            if lf_xy_positions is None:
                lf_xy_positions = xy_positions
            else:
                for i in range(len(lf_xy_positions)):
                    lf_xy_positions[i] = torch.cat([
                        lf_xy_positions[i],
                        xy_positions[i]
                    ])
            expand_img = expand_img.transpose(2, 3)

            hw_interval = p_interval
            for h in range(0, grid_line.size(0), hw_interval):
                sub_out_positions = [o[h:h + hw_interval] for o in out_positions]
                sub_xy_positions = [o[h:h + hw_interval] for o in xy_positions]
                sub_sub_select_idx = sub_select_idx[h:h + hw_interval]

                line = torch.nn.functional.grid_sample(expand_img[h:h + hw_interval].detach(), grid_line[h:h + hw_interval], align_corners=True)
                line = line.transpose(2, 3)

                for l in line:
                    l = l.transpose(0, 1).transpose(1, 2)
                    l = (l + 1) * 128
                    l_np = l.data.cpu().numpy()
                    line_imgs.append(l_np)

                # Mehreen add: to avoid out-of-memory errors, a large batch
                # has to be split up for the HW network to process. This
                # case arises when SOL finds too many lines on a page.
                batch, channels, old_ht, old_width = line.size()
                line = line.detach().cpu()
                total_todo = batch

                start_index = 0
                while total_todo > 0:
                    mini_batch_size = min(self.max_lines_per_image, total_todo)
                    partial_lines = line[start_index:start_index + mini_batch_size, :, :, :]
                    start_index += mini_batch_size
                    total_todo = total_todo - mini_batch_size
                    partial_lines = partial_lines.to(self.device)
                    out = self.hw(partial_lines)
                    if "cuda" in device:
                        torch.cuda.empty_cache()
                    out = out.transpose(0, 1)
                    hw_out.append(out)

        hw_out = torch.cat(hw_out, 0)

        return {
            "original_sol": original_starts,
            "sol": positions,
            "lf": lf_xy_positions,
            "hw": hw_out,
            "results_scale": results_scale,
            "line_imgs": line_imgs
        }
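The mini-batching loop above is a generic chunking pattern; a standalone sketch of the same idea, with run_in_chunks and model_fn as hypothetical names:

import torch

def run_in_chunks(model_fn, lines, max_per_batch=8):
    # Feed the network at most max_per_batch line images at a time,
    # then concatenate the per-chunk outputs along the batch dimension
    outs = []
    for start in range(0, lines.size(0), max_per_batch):
        outs.append(model_fn(lines[start:start + max_per_batch]))
    return torch.cat(outs, 0)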
py3/e2e/e2e_postprocessing.py ADDED
@@ -0,0 +1,182 @@
from utils import string_utils, error_rates
import numpy as np
from . import nms
import copy


def get_trimmed_polygons(out):
    all_polygons = []
    for j in range(out['lf'][0].shape[0]):
        begin = out['beginning'][j]
        end = out['ending'][j]
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))
        points = []
        for i in range(begin_f, end_f + 1):

            if i == begin_f:
                p0 = out['lf'][i][j]
                p1 = out['lf'][i + 1][j]
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t

            elif i == end_f:
                p0 = out['lf'][i - 1][j]
                if i != len(out['lf']):
                    p1 = out['lf'][i][j]
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                p = out['lf'][i][j]

            points.append(p)
        points = np.array(points)
        all_polygons.append(points)
    return all_polygons


def trim_ends(out):

    lf_length = len(out['lf'])
    hw = out['hw']
    # Mehreen: hw has shape (lines, steps, chars), e.g. (14, 361, 197);
    # selected then has shape (lines, steps), e.g. (14, 361)
    selected = hw.argmax(axis=-1)
    beginning = np.argmax(selected != 0, axis=1)
    ending = selected.shape[1] - 1 - np.argmax(selected[:, ::-1] != 0, axis=1)

    beginning_percent = (beginning + 0.5) / float(selected.shape[1])
    ending_percent = (ending + 0.5) / float(selected.shape[1])

    lf_beginning = lf_length * beginning_percent
    lf_ending = lf_length * ending_percent

    out['beginning'] = lf_beginning
    out['ending'] = lf_ending
    return out


def filter_on_pick(out, pick):
    out['sol'] = out['sol'][pick]
    out['lf'] = [l[pick] for l in out['lf']]
    out['hw'] = out['hw'][pick]

    if 'idx' in out:
        out['idx'] = out['idx'][pick]
    if 'beginning' in out:
        out['beginning'] = out['beginning'][pick]
    if 'ending' in out:
        out['ending'] = out['ending'][pick]


def filter_on_pick_no_copy(out, pick):
    output = {}
    output['sol'] = out['sol'][pick]
    output['lf'] = [l[pick] for l in out['lf']]
    output['hw'] = out['hw'][pick]
    if 'idx' in out:
        output['idx'] = out['idx'][pick]
    if 'beginning' in out:
        output['beginning'] = out['beginning'][pick]
    if 'ending' in out:
        output['ending'] = out['ending'][pick]
    return output


def select_non_empty_string(out):
    selected = out['hw'].argmax(axis=-1)
    return np.where(selected.sum(axis=1) != 0)


def postprocess(out, **kwargs):
    out = copy.copy(out)

    # Postprocessing should be done with numpy data
    sol_threshold = kwargs.get("sol_threshold", None)
    sol_nms_threshold = kwargs.get("sol_nms_threshold", None)
    lf_nms_params = kwargs.get('lf_nms_params', None)
    lf_nms_2_params = kwargs.get('lf_nms_2_params', None)

    if sol_threshold is not None:
        pick = np.where(out['sol'][:, -1] > sol_threshold)
        filter_on_pick(out, pick)

    # Mehreen: this is passed as None from run_hwr in decode_one_img_with_info
    if sol_nms_threshold is not None:
        raise Exception("This is not correct")
        pick = nms.sol_nms_single(out['sol'], sol_nms_threshold)
        out['sol'] = out['sol'][pick]

    # Mehreen: when post-processing, this part is used;
    # sample_config: lf_nms_range: [0, 6], lf_nms_threshold: 0.5
    if lf_nms_params is not None:
        confidences = out['sol'][:, -1]
        overlap_range = lf_nms_params['overlap_range']
        overlap_thresh = lf_nms_params['overlap_threshold']

        lf_setup = np.concatenate([l[None, ...] for l in out['lf']])
        lf_setup = [lf_setup[:, i] for i in range(lf_setup.shape[1])]

        pick = nms.lf_non_max_suppression_area(lf_setup, confidences, overlap_range, overlap_thresh)
        filter_on_pick(out, pick)

    # Mehreen: when post-processing, this part is None from decode_one_img_with_info
    if lf_nms_2_params is not None:
        confidences = out['sol'][:, -1]
        overlap_thresh = lf_nms_2_params['overlap_threshold']
        refined_lf = get_trimmed_polygons(out)
        pick = nms.lf_non_max_suppression_area(refined_lf, confidences, None, overlap_thresh)
        filter_on_pick(out, pick)

    return out


def read_order(out):
    first_pt = out['lf'][0][:, :2, 0]

    first_pt = first_pt[:, ::-1]
    first_pt = np.concatenate([first_pt, np.arange(first_pt.shape[0])[:, None]], axis=1)
    first_pt = first_pt.tolist()

    first_pt.sort()

    return [int(p[2]) for p in first_pt]


def decode_handwriting(out, idx_to_char):
    hw_out = out['hw']
    list_of_pred = []
    list_of_raw_pred = []
    for i in range(hw_out.shape[0]):
        logits = hw_out[i, ...]
        pred, raw_pred = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(pred, idx_to_char, False)
        raw_pred_str = string_utils.label2str_single(raw_pred, idx_to_char, True)
        list_of_pred.append(pred_str)
        list_of_raw_pred.append(raw_pred_str)

    return list_of_pred, list_of_raw_pred


def results_to_numpy(out):
    return {
        "sol": out['sol'].data.cpu().numpy()[:, 0, :],
        "lf": [l.data.cpu().numpy() for l in out['lf']] if out['lf'] is not None else None,
        "hw": out['hw'].data.cpu().numpy(),
        "results_scale": out['results_scale'],
        "line_imgs": out['line_imgs'],
    }


def align_to_gt_lines(decoded_hw, gt_lines):
    costs = []
    for i in range(len(decoded_hw)):
        costs.append([])
        for j in range(len(gt_lines)):
            pred = decoded_hw[i]
            gt = gt_lines[j]
            cer = error_rates.cer(gt, pred)
            costs[i].append(cer)

    costs = np.array(costs)
    min_idx = costs.argmin(axis=0)
    min_val = costs.min(axis=0)

    return min_idx, min_val
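A sketch of how these routines chain together; raw_out and idx_to_char are assumed to come from the E2E model and the character set, and the thresholds mirror config_2600.yaml:

import numpy as np
from e2e import e2e_postprocessing as pp

out = pp.results_to_numpy(raw_out)                 # raw_out: dict from E2EModel.forward
out['idx'] = np.arange(out['sol'].shape[0])
out = pp.postprocess(out,
                     sol_threshold=0.1,
                     lf_nms_params={'overlap_range': [0, 6],
                                    'overlap_threshold': 0.5})
pp.filter_on_pick(out, pp.read_order(out))         # reading order: top-to-bottom
preds, raw_preds = pp.decode_handwriting(out, idx_to_char)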
py3/e2e/forward_pass.py ADDED
@@ -0,0 +1,86 @@
from e2e import e2e_model, e2e_postprocessing
from e2e.e2e_model import E2EModel

from . import validation_utils

from utils import error_rates

import itertools
import copy
import numpy as np
import cv2


def forward_pass(x, e2e, config, thresholds, idx_to_char, update_json=False):

    gt_lines = x['gt_lines']
    gt = "\n".join(gt_lines)

    out_original = e2e(x)
    results = {}
    if out_original is None:
        # TODO: not a good way to handle this, but fine for now
        # (fixed: the original fell through here instead of returning)
        return results, None, None

    # Fixed: these helpers live in e2e_postprocessing, not on E2EModel
    out_original = e2e_postprocessing.results_to_numpy(out_original)
    out_original['idx'] = np.arange(out_original['sol'].shape[0])

    decoded_hw, decoded_raw_hw = e2e_postprocessing.decode_handwriting(out_original, idx_to_char)
    pick, costs = e2e_postprocessing.align_to_gt_lines(decoded_hw, gt_lines)

    most_ideal_pred_lines, improved_idxs = validation_utils.update_ideal_results(pick, costs, decoded_hw, x['gt_json'])
    # if update_json:
    #     validation_utils.save_improved_idxs(improved_idxs, decoded_hw,
    #                                         decoded_raw_hw, out_original,
    #                                         x, config[dataset_lookup]['json_folder'], config['alignment']['trim_to_sol'])

    sol_thresholds = thresholds[0]
    sol_thresholds_idx = list(range(len(sol_thresholds)))

    lf_nms_ranges = thresholds[1]
    lf_nms_ranges_idx = list(range(len(lf_nms_ranges)))

    lf_nms_thresholds = thresholds[2]
    lf_nms_thresholds_idx = list(range(len(lf_nms_thresholds)))

    most_ideal_pred_lines = "\n".join(most_ideal_pred_lines)

    ideal_pred_lines = [decoded_hw[i] for i in pick]
    ideal_pred_lines = "\n".join(ideal_pred_lines)

    error = error_rates.cer(gt, ideal_pred_lines)
    ideal_result = error

    error = error_rates.cer(gt, most_ideal_pred_lines)
    most_ideal_result = error

    for key in itertools.product(sol_thresholds_idx, lf_nms_ranges_idx, lf_nms_thresholds_idx):
        i, j, k = key
        sol_threshold = sol_thresholds[i]
        lf_nms_range = lf_nms_ranges[j]
        lf_nms_threshold = lf_nms_thresholds[k]

        out = copy.copy(out_original)

        out = e2e_postprocessing.postprocess(out,
                                             sol_threshold=sol_threshold,
                                             lf_nms_params={
                                                 "overlap_range": lf_nms_range,
                                                 "overlap_threshold": lf_nms_threshold
                                             })
        order = e2e_postprocessing.read_order(out)
        e2e_postprocessing.filter_on_pick(out, order)

        # draw_img = visualization.draw_output(out, img)
        # cv2.imwrite("test_b_samples/test_img_{}.png".format(a), draw_img)

        preds = [decoded_hw[i] for i in out['idx']]
        pred = "\n".join(preds)

        error = error_rates.cer(gt, pred)

        results[key] = error

    return results, ideal_result, most_ideal_result
py3/e2e/handwriting_alignment_loss.py ADDED
@@ -0,0 +1,125 @@
from utils import string_utils, error_rates
import torch
from scipy.optimize import linear_sum_assignment
import numpy as np
from torch.autograd import Variable


def accumulate_scores(out, out_positions, xy_positions, gt_state, idx_to_char):

    preds = out.transpose(0, 1).cpu()
    batch_size = preds.size(1)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char, False)
        pred_str_full = string_utils.label2str_single(raw_decode_full, idx_to_char, True)

        sub_out_positions = [o[i].data.cpu().numpy().tolist() for o in out_positions]
        sub_xy_positions = [o[i].data.cpu().numpy().tolist() for o in xy_positions]

        for gt_obj in gt_state:
            gt_text = gt_obj['gt']
            cer = error_rates.cer(gt_text, pred_str)

            # This is a terrible way to do this...
            gt_obj['errors'] = gt_obj.get('errors', [])
            gt_obj['pred'] = gt_obj.get('pred', [])
            gt_obj['pred_full'] = gt_obj.get('pred_full', [])
            gt_obj['path'] = gt_obj.get('path', [])
            gt_obj['path_xy'] = gt_obj.get('path_xy', [])

            gt_obj['errors'].append(cer)
            gt_obj['pred'].append(pred_str)
            gt_obj['pred_full'].append(pred_str_full)
            gt_obj['path'].append(sub_out_positions)
            gt_obj['path_xy'].append(sub_xy_positions)


def update_alignment(out, gt_lines, alignments, idx_to_char, idx_mapping, sol_positions):

    preds = out.cpu()
    batch_size = preds.size(1)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    for i, logits in enumerate(out.data.cpu().numpy()):
        raw_decode, raw_decode_full = string_utils.naive_decode(logits)
        pred_str = string_utils.label2str_single(raw_decode, idx_to_char, False)

        for j, gt in enumerate(gt_lines):
            cer = error_rates.cer(gt, pred_str)
            global_i = idx_mapping[i]
            c = sol_positions[i, 0, -1].data[0]

            # alignment_error = cer
            alignment_error = cer + 0.1 * (1.0 - c)

            if alignment_error < alignments[j][0]:
                alignments[j][0] = alignment_error
                alignments[j][1] = global_i
                # alignments[j][2] = out[i][:, None, :]
                alignments[j][2] = None
                alignments[j][3] = pred_str


def alignment(predictions, hw_scores, alpha_alignment=0.1, alpha_backprop=0.1):
    confidences = predictions[:, :, 4]

    log_confidences = torch.log(confidences + 1e-10)
    log_one_minus_confidences = torch.log(1.0 - confidences + 1e-10)

    expanded_log_confidences = log_confidences[:, :, None].expand(confidences.size(0), confidences.size(1), hw_scores.size(2))
    expanded_log_one_minus_confidences = log_one_minus_confidences[:, :, None].expand(confidences.size(0), confidences.size(1), hw_scores.size(2))

    C = alpha_alignment * hw_scores - expanded_log_confidences + expanded_log_one_minus_confidences

    C = C.data.cpu().numpy()
    X = np.zeros_like(C)

    idxs = []
    for b in range(C.shape[0]):
        C_i = C[b]
        row_ind, col_ind = linear_sum_assignment(C_i.T)
        idxs.append((col_ind, row_ind))

    return idxs


def loss(preds, non_hw_sol, hw_sol, gt_lines, char_to_idx, criterion):
    label_lengths = []
    all_labels = []
    for gt_str in gt_lines:
        l = string_utils.str2label_single(gt_str, char_to_idx)
        all_labels.append(l)
        label_lengths.append(len(l))

    all_labels = np.concatenate(all_labels)
    label_lengths = np.array(label_lengths)

    labels = torch.from_numpy(all_labels.astype(np.int32))
    label_lengths = torch.from_numpy(label_lengths.astype(np.int32))

    labels = Variable(labels, requires_grad=False)
    label_lengths = Variable(label_lengths, requires_grad=False)

    batch_size = preds.size(0)
    preds_size = Variable(torch.IntTensor([preds.size(0)] * batch_size))

    ctc_loss = 1e-2 * criterion(preds.cpu(), labels, preds_size, label_lengths)

    log_one_minus_confidences = torch.log(1.0 - non_hw_sol[:, :, 0] + 1e-10)
    log_confidences = torch.log(hw_sol[:, :, 0] + 1e-10)

    selected_confidence = log_confidences.sum()
    not_selected_confidence = log_one_minus_confidences.sum()

    confidence_loss = -selected_confidence - not_selected_confidence

    return ctc_loss + confidence_loss.cpu()
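The cost matrix built in alignment() is solved with the Hungarian algorithm; a toy standalone example of that step:

import numpy as np
from scipy.optimize import linear_sum_assignment

# Rows: predictions, columns: ground-truth lines (toy CER-like costs)
C = np.array([[0.1, 0.9, 0.8],
              [0.7, 0.2, 0.9]])
row_ind, col_ind = linear_sum_assignment(C.T)  # same call pattern as alignment()
print(list(zip(col_ind, row_ind)))             # [(0, 0), (1, 1)]: pred i <-> gt j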
py3/e2e/nms.py ADDED
@@ -0,0 +1,162 @@
import torch
import numpy as np
import pyclipper


def sol_non_max_suppression(start_torch, overlap_thresh):
    # TODO: make this work with batches

    # Rotation is not taken into account
    start = start_torch.data.cpu().numpy()

    pick = sol_nms_single(start[0], overlap_thresh)

    zero_idx = [0 for _ in range(len(pick))]

    select = (zero_idx, pick)
    return start_torch[select][None, ...]


def sol_nms_single(start, overlap_thresh):
    # Based on https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
    # Maybe could port to pytorch to work over the tensors directly

    # Mehreen comment: start[:,0] is confidence, start[:,1] is x,
    # start[:,2] is y, start[:,3] is scale and start[:,4] is theta.
    # So (x1, y1) is the top left corner and (x2, y2) is the bottom right corner
    x1 = start[:, 1] - start[:, 3]
    y1 = start[:, 2] - start[:, 3]

    x2 = start[:, 1] + start[:, 3]
    y2 = start[:, 2] + start[:, 3]

    c = start[:, 0]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(c)

    pick = []
    while len(idxs) > 0:

        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])

        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        overlap = (w * h) / area[idxs[:last]]

        idxs = np.delete(idxs, np.concatenate(([last],
                                               np.where(overlap > overlap_thresh)[0])))
    return pick


def lf_non_max_suppression_area(lf_xy_positions, confidences, overlap_range, overlap_thresh):
    lf_xy_positions = [l[:, :2, :2] for l in lf_xy_positions]

    c = confidences

    bboxes = []
    center_lines = []
    scales = []
    for i in range(len(lf_xy_positions)):
        pts = lf_xy_positions[i]
        if overlap_range is not None:
            pts = pts[overlap_range[0]: overlap_range[1]]

        f = pts[0]
        delta = f[:, 0] - f[:, 1]
        scale = np.sqrt((delta ** 2).sum())
        scales.append(scale)

        center_lines.append((pts[:, :, 0] + pts[:, :, 1]) / 2.0)

        min_x = pts[:, 0].min()
        max_x = pts[:, 0].max()
        min_y = pts[:, 1].min()
        max_y = pts[:, 1].max()

        bboxes.append((min_x, min_y, max_x, max_y))

    bboxes = np.array(bboxes)

    if len(bboxes.shape) < 2:
        return []

    x1 = bboxes[:, 0]
    y1 = bboxes[:, 1]
    x2 = bboxes[:, 2]
    y2 = bboxes[:, 3]

    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(c)

    overlapping_regions = []
    pick = []
    while len(idxs) > 0:

        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)

        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])

        # Compute the width and height of the bounding box
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)

        # Compute the ratio of overlap
        overlap_bb = (w * h) / area[idxs[:last]]

        overlap = []
        for step, j in enumerate(idxs[:last]):
            # Skip anything that doesn't actually have any overlap
            if overlap_bb[step] < 0.1:
                overlap.append(0)
                continue

            path0 = center_lines[i]
            path1 = center_lines[j]

            path = np.concatenate([path0, path1[::-1]])
            path = [[int(x[0]), int(x[1])] for x in path]

            expected_scale = (scales[i] + scales[j]) / 2.0
            one_off_area = expected_scale ** 2 * (path0.shape[0] + path1.shape[0]) / 2.0

            simple_path = pyclipper.SimplifyPolygon(path, pyclipper.PFT_NONZERO)
            inter_area = 0
            for path in simple_path:
                inter_area += abs(pyclipper.Area(path))

            area_ratio = inter_area / one_off_area
            area_ratio = 1.0 - area_ratio

            overlap.append(area_ratio)

        overlap = np.array(overlap)
        to_delete = np.concatenate(([last], np.where(overlap > overlap_thresh)[0]))
        idxs = np.delete(idxs, to_delete)

    return pick
py3/e2e/validation_utils.py ADDED
@@ -0,0 +1,137 @@
1
+ from utils import error_rates
2
+ import copy
3
+ import os
4
+ import cv2
5
+ import json
6
+
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+
11
+ def interpolate(key1, key2, lf, lf_idx, step_percent):
12
+ x0 = lf[lf_idx][key1]
13
+ y0 = lf[lf_idx][key2]
14
+ x1 = lf[lf_idx+1][key1]
15
+ y1 = lf[lf_idx+1][key2]
16
+
17
+ x = x1 * step_percent + x0 * (1.0 - step_percent)
18
+ y = y1 * step_percent + y0 * (1.0 - step_percent)
19
+
20
+ return x, y
21
+
22
+ def get_subdivide_pt(i, pred_full, lf):
23
+ percent = (float(i)+0.5) / float(len(pred_full))
24
+ lf_percent = (len(lf)-1) * percent
25
+
26
+ lf_idx = int(np.floor(lf_percent))
27
+ step_percent = lf_percent - lf_idx
28
+
29
+ x0, y0 = interpolate("x0", "y0", lf, lf_idx, step_percent)
30
+ x1, y1 = interpolate("x1", "y1", lf, lf_idx, step_percent)
31
+
32
+ return x0, y0, x1, y1
33
+
34
+ def save_improved_idxs(improved_idxs, decoded_hw, decoded_raw_hw, out, x, json_folder):
35
+
36
+ output_lines = [{
37
+ "gt": gt['gt']
38
+ } for gt in x['gt_json']]
39
+
40
+
41
+ # for i in improved_idxs:
42
+ for i in range(len(output_lines)):
43
+
44
+ if not i in improved_idxs:
45
+ output_lines[i] = x['gt_json'][i]
46
+ continue
47
+
48
+ k = improved_idxs[i]
49
+
50
+ # We want to trim the LF results
51
+ # good to keep around the full length of the prediciton
52
+ # so we can generate the full line-level images later
53
+ # at a different resolution
54
+ line_points = []
55
+ after_line_points = []
56
+ lf_path = out['lf']
57
+ end = out['ending'][k]
58
+ for j in range(len(lf_path)):
59
+ p = lf_path[j][k]
60
+ s = out['results_scale']
61
+
62
+ if j > end:
63
+ after_line_points.append({
64
+ "x0": p[0][1] * s,
65
+ "x1": p[0][0] * s,
66
+ "y0": p[1][1] * s,
67
+ "y1": p[1][0] * s
68
+ })
69
+ else:
70
+ line_points.append({
71
+ "x0": p[0][1] * s,
72
+ "x1": p[0][0] * s,
73
+ "y0": p[1][1] * s,
74
+ "y1": p[1][0] * s
75
+ })
76
+
77
+ begin = out['beginning'][k]
78
+ begin_f = int(np.floor(begin))
79
+ p0 = out['lf'][begin_f][k]
80
+ if begin_f+1 >= len(out['lf']):
81
+ p = p0
82
+ else:
83
+ p1 = out['lf'][begin_f+1][k]
84
+ t = begin - np.floor(begin)
85
+ p = p0 * (1 - t) + p1 * t
86
+
87
+ sol_point = {
88
+ "x0": p[0][1] * s,
89
+ "x1": p[0][0] * s,
90
+ "y0": p[1][1] * s,
91
+ "y1": p[1][0] * s
92
+ }
93
+
94
+ img_file_name = "{}_{}.png".format(x['img_key'], i)
95
+
96
+ output_lines[i]['pred'] = decoded_hw[k]
97
+ output_lines[i]['pred_full'] = decoded_raw_hw[k]
98
+ output_lines[i]['sol'] = sol_point
99
+ output_lines[i]['lf'] = line_points
100
+ output_lines[i]['after_lf'] = after_line_points
101
+ output_lines[i]['start_idx'] = 1 #TODO: update to backward idx
102
+ output_lines[i]['hw_path'] = img_file_name
103
+
104
+ line_img = out['line_imgs'][k]
105
+
106
+ full_img_file_name = os.path.join(json_folder, img_file_name)
107
+ cv2.imwrite(full_img_file_name, line_img)
108
+
109
+ json_path = x['json_path']
110
+ with open(json_path, 'w') as f:
111
+ # print('written data to:', f)
112
+ json.dump(output_lines, f)
113
+
114
+ def update_ideal_results(pick, costs, decoded_hw, gt_json):
115
+
116
+ most_ideal_pred = []
117
+ improved_idxs = {}
118
+
119
+ for i in range(len(gt_json)):
120
+ gt_obj = gt_json[i]
121
+
122
+ prev_pred = gt_obj.get('pred', '')
123
+ gt = gt_obj['gt']
124
+
125
+ pred = decoded_hw[pick[i]]
126
+
127
+ prev_cer = error_rates.cer(gt, prev_pred)
128
+ cer = costs[i]
129
+
130
+ if cer > prev_cer or len(pred) == 0:
131
+ most_ideal_pred.append(prev_pred)
132
+ continue
133
+
134
+ most_ideal_pred.append(pred)
135
+ improved_idxs[i] = pick[i]
136
+
137
+ return most_ideal_pred, improved_idxs
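A hedged sketch of how update_ideal_results is meant to be driven; the gt_json entries, alignment, and decodings below are made up for illustration:

gt_json = [{'gt': 'line one'}, {'gt': 'line two', 'pred': 'line twp'}]
decoded_hw = ['line one', 'line two']   # current network decodings
pick = [0, 1]                           # alignment of GT lines to decodings
costs = [error_rates.cer(g['gt'], decoded_hw[p]) for g, p in zip(gt_json, pick)]

best, improved = update_ideal_results(pick, costs, decoded_hw, gt_json)
# 'improved' maps GT line indices to decoding indices that lowered the CER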
py3/e2e/visualization.py ADDED
@@ -0,0 +1,176 @@
import numpy as np
import cv2

def draw_output(out, img):
    img = img.copy()

    # print(out['lf'][0].shape[0], out['sol'].shape[0])
    # sys.exit()

    for i in range(out['sol'].shape[0]):
        j = i

        p = out['sol'][i]

        c = int(255 * p[-1])
        color = (c, 0, 255-c)

        x = p[0]
        y = p[1]
        r = p[2]
        x_comp = np.cos(r)
        y_comp = -np.sin(r)
        s = p[3]

        rx = x + s * x_comp * 2
        ry = y + s * y_comp * 2

        rx2 = x - s * x_comp
        ry2 = y - s * y_comp

        rx = int(rx)
        ry = int(ry)

        rx2 = int(rx2)
        ry2 = int(ry2)

        x = int(x)
        y = int(y)
        scale = abs(int(s))

        color = (0, 0, 255)

        cv2.circle(img, (x, y), int(scale), color, 2)
        cv2.circle(img, (x, y), 4, color, -1)
        cv2.arrowedLine(img, (x, y), (rx, ry), color, 2, tipLength=0.25)
        # cv2.line(img, (rx2, ry2), (rx, ry), color, 2)
        cv2.putText(img, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

        # for j in range(out['lf'][0].shape[0]):
        begin = out['beginning'][j]
        end = out['ending'][j]

        last_xy = None
        # for i in xrange(len(out['lf'])):
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))
        for i in range(begin_f, end_f+1):

            if i == begin_f:
                p0 = out['lf'][i][j].mean(axis=1)
                p1 = out['lf'][i+1][j].mean(axis=1)
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t

            elif i == end_f:

                p0 = out['lf'][i-1][j].mean(axis=1)
                if i != len(out['lf']):
                    p1 = out['lf'][i][j].mean(axis=1)
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                p = out['lf'][i][j].mean(axis=1)

            x = p[0]
            y = p[1]

            x = int(x)
            y = int(y)

            # c = int(255 * p[-1])
            # color = (c, 0, 255-c)
            color = (0, 150, 0)
            cv2.circle(img, (x, y), 4, color, -1)

            if last_xy is not None:
                cv2.line(img, (x, y), last_xy, color, int(s))

            last_xy = (x, y)
    return img

def draw_output_original(out, img):
    img = img.copy()

    for j in range(out['lf'][0].shape[0]):
        begin = out['beginning'][j]
        end = out['ending'][j]

        last_xy = None
        # for i in xrange(len(out['lf'])):
        begin_f = int(np.floor(begin))
        end_f = int(np.ceil(end))
        for i in range(begin_f, end_f+1):

            if i == begin_f:
                p0 = out['lf'][i][j].mean(axis=1)
                p1 = out['lf'][i+1][j].mean(axis=1)
                t = begin - np.floor(begin)
                p = p0 * (1 - t) + p1 * t

            elif i == end_f:

                p0 = out['lf'][i-1][j].mean(axis=1)
                if i != len(out['lf']):
                    p1 = out['lf'][i][j].mean(axis=1)
                    t = end - np.floor(end)
                    p = p0 * (1 - t) + p1 * t
                else:
                    p = p0
            else:
                p = out['lf'][i][j].mean(axis=1)

            x = p[0]
            y = p[1]

            x = int(x)
            y = int(y)

            color = (0, 0, 0)
            cv2.circle(img, (x, y), 4, color, -1)

            if last_xy is not None:
                cv2.line(img, (x, y), last_xy, color, 2)

            last_xy = (x, y)

    for i in range(out['sol'].shape[0]):

        p = out['sol'][i]

        c = int(255 * p[-1])
        color = (c, 0, 255-c)

        x = p[0]
        y = p[1]
        r = p[2]
        x_comp = np.cos(r)
        y_comp = -np.sin(r)
        s = p[3]

        rx = x + s * x_comp * 2
        ry = y + s * y_comp * 2

        rx2 = x - s * x_comp
        ry2 = y - s * y_comp

        rx = int(rx)
        ry = int(ry)

        rx2 = int(rx2)
        ry2 = int(ry2)

        x = int(x)
        y = int(y)
        scale = abs(int(s))

        # color = (0, 0, 255)

        cv2.circle(img, (x, y), int(scale), color, 2)
        cv2.circle(img, (x, y), 4, color, -1)
        cv2.arrowedLine(img, (x, y), (rx, ry), color, 2, tipLength=0.25)
        # cv2.line(img, (rx2, ry2), (rx, ry), color, 2)
        cv2.putText(img, str(i), (x, y), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
    return img
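Both drawing routines return an annotated copy of the page image. A minimal usage sketch; the file names are illustrative, and 'out' is assumed to be the network output dict with 'sol', 'lf', 'beginning', and 'ending' as indexed above:

import cv2

img = cv2.imread('page.png')
vis = draw_output(out, img)
cv2.imwrite('page_vis.png', vis)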
py3/hw/__init__.py ADDED
File without changes
py3/hw/cnn_lstm.py ADDED
@@ -0,0 +1,117 @@

import torch
from torch import nn
import matplotlib.pyplot as plt
import matplotlib.patches as patches

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True, dropout=0.5, num_layers=2)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):

        # print('blstm input', input.size())
        recurrent, notused = self.rnn(input)
        # print('rnn output', recurrent.size(), 'not used', notused)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T * b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        # print('.....', output.size())
        return output

class CRNN(nn.Module):

    def __init__(self, cnnOutSize, nc, nclass, nh, n_rnn=2, leakyRelu=False, use_instance_norm=False):
        super(CRNN, self).__init__()

        ks = [3, 3, 3, 3, 3, 3, 2]
        ps = [1, 1, 1, 1, 1, 1, 0]
        ss = [1, 1, 1, 1, 1, 1, 1]
        nm = [64, 128, 256, 256, 512, 512, 512]

        cnn = nn.Sequential()

        def convRelu(i, batchNormalization=False):
            nIn = nc if i == 0 else nm[i - 1]
            nOut = nm[i]
            cnn.add_module('conv{0}'.format(i),
                           nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
            if batchNormalization:
                if not use_instance_norm:
                    cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
                else:
                    cnn.add_module(f'instancenorm{i}', nn.InstanceNorm2d(nOut))
            if leakyRelu:
                cnn.add_module('relu{0}'.format(i),
                               nn.LeakyReLU(0.2, inplace=True))
            else:
                cnn.add_module('relu{0}'.format(i), nn.ReLU(True))

        convRelu(0)
        cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))  # 64x16x64
        convRelu(1)
        cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))  # 128x8x32
        convRelu(2, True)
        convRelu(3)
        cnn.add_module('pooling{0}'.format(2),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 256x4x16
        convRelu(4, True)
        convRelu(5)
        cnn.add_module('pooling{0}'.format(3),
                       nn.MaxPool2d((2, 2), (2, 1), (0, 1)))  # 512x2x16
        convRelu(6, True)  # 512x1x16

        self.cnn = cnn
        # Mehreen: nclass is the total output characters. nh is set to 512 by create_model
        self.rnn = BidirectionalLSTM(cnnOutSize, nh, nclass)
        ### MEHREEN ADD PARAM dim=2
        self.softmax = nn.LogSoftmax(dim=2)

    def forward(self, input):
        conv = self.cnn(input)
        b, c, h, w = conv.size()
        # print('.....', input.size())
        # print('....', b, c, h, w)

        if torch.any(torch.isnan(conv)):
            print("CONV IS NAN (b,c,h,w) = ", b, c, h, w)

        # iimg = input.cpu()[0].permute(2, 1, 0)
        # print('....iimg.size', input.size())
        # plt.imshow(iimg)
        # plt.show()

        #### MEHREEN change this
        # conv = conv.view(b, -1, w)  ### <-- original
        # to
        conv = torch.reshape(conv, (b, c*h, w))
        ### End mehreen
        conv = conv.permute(2, 0, 1)  # [w, b, c]
        # rnn features
        output = self.rnn(conv)
        if torch.any(torch.isnan(output)):
            print("OUTPUT FROM RNN IS NAN")
        ### MEHREEN ADD
        output = self.softmax(output)
        if torch.any(torch.isnan(output)):
            print("OUTPUT FROM SOFTMAX IS NAN")
        if torch.any(torch.isinf(output)):
            print("OUTPUT FROM SOFTMAX IS INF")
        ### END MEHREEN
        return output

def create_model(config):
    use_instance_norm = False
    if 'use_instance_norm' in config and config['use_instance_norm']:
        use_instance_norm = True
    crnn = CRNN(config['cnn_out_size'], config['num_of_channels'], config['num_of_outputs'], 512,
                use_instance_norm=use_instance_norm)
    return crnn
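As a shape check for the CRNN above: with this pooling schedule a line image of height 60 collapses to a feature map of height 2, so cnn_out_size must be 512 * 2 = 1024. A minimal sketch; the config values and character count are illustrative:

import torch

config = {'cnn_out_size': 1024, 'num_of_channels': 3, 'num_of_outputs': 80}
crnn = create_model(config)

x = torch.zeros(2, 3, 60, 400)   # batch of two 60x400 RGB line images
log_probs = crnn(x)              # (width_steps, batch, num_of_outputs) log-probs for CTC
print(log_probs.shape)           # torch.Size([101, 2, 80])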
py3/lf/__init__.py ADDED
File without changes
py3/lf/fast_patch_view.py ADDED
@@ -0,0 +1,96 @@
import torch
from torch.autograd import Variable
import sys

from utils import transformation_utils

def get_patches(image, crop_window, grid_gen, allow_end_early=False, device="cuda"):

    dtype = torch.FloatTensor
    if 'cuda' in device:
        dtype = torch.cuda.FloatTensor

    pts = Variable(torch.FloatTensor([
        [-1.0, -1.0,  1.0, 1.0],
        [-1.0,  1.0, -1.0, 1.0],
        [ 1.0,  1.0,  1.0, 1.0]
    ]).type_as(image.data), requires_grad=False)[None, ...]

    bounds = crop_window.matmul(pts)

    min_bounds, _ = bounds.min(dim=-1)
    max_bounds, _ = bounds.max(dim=-1)
    d_bounds = max_bounds - min_bounds
    floored_idx_offsets = torch.floor(min_bounds[:, :2].data).long()
    max_d_bounds = d_bounds.max(dim=0)[0].max(dim=0)[0]
    crop_size = torch.ceil(max_d_bounds).long()
    if image.is_cuda:
        crop_size = crop_size.cuda()
    w = crop_size.item()

    memory_space = Variable(torch.zeros(d_bounds.size(0), 3, w, w).type_as(image.data), requires_grad=False)
    translations = []
    N = transformation_utils.compute_renorm_matrix(memory_space)
    all_skipped = True

    for b_i in range(memory_space.size(0)):

        o = floored_idx_offsets[b_i]

        t = Variable(dtype([
            [1, 0, -o[0]],
            [0, 1, -o[1]],
            [0, 0,  1]
        ]), requires_grad=False).expand(3, 3)
        translations.append(N.mm(t)[None, ...])

        skip_slice = False

        s_x = (o[0], o[0]+w)
        s_y = (o[1], o[1]+w)
        t_x = (0, w)
        t_y = (0, w)
        if o[0] < 0:
            s_x = (0, w+o[0])
            t_x = (-o[0], w)

        if o[1] < 0:
            s_y = (0, w+o[1])
            t_y = (-o[1], w)

        if o[0]+w >= image.size(2):
            s_x = (s_x[0], image.size(2))
            t_x = (t_x[0], image.size(2) - s_x[0])

        if o[1]+w >= image.size(3):
            # clamp the y slice the same way as the x slice above
            s_y = (s_y[0], image.size(3))
            t_y = (t_y[0], image.size(3) - s_y[0])

        if s_x[0] >= s_x[1]:
            skip_slice = True

        if t_x[0] >= t_x[1]:
            skip_slice = True

        if s_y[0] >= s_y[1]:
            skip_slice = True

        if t_y[0] >= t_y[1]:
            skip_slice = True

        if not skip_slice:
            all_skipped = False
            i_s = image[b_i:b_i+1, :, s_x[0]:s_x[1], s_y[0]:s_y[1]]
            memory_space[b_i:b_i+1, :, t_x[0]:t_x[1], t_y[0]:t_y[1]] = i_s

    if all_skipped and allow_end_early:
        return None

    translations = torch.cat(translations, 0)
    grid = grid_gen(translations.bmm(crop_window))
    grid = grid[:, :, :, 0:2] / grid[:, :, :, 2:3]

    resampled = torch.nn.functional.grid_sample(memory_space.transpose(2, 3), grid, mode='bilinear',
                                                align_corners=True)

    return resampled
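get_patches stitches a zero-padded copy of each crop region into memory_space and then resamples it with grid_sample. A minimal sketch of that final resampling step in isolation, with toy tensors rather than this repo's transforms:

import torch
import torch.nn.functional as F

img = torch.arange(16.0).reshape(1, 1, 4, 4)
theta = torch.tensor([[[0.5, 0.0, 0.0],
                       [0.0, 0.5, 0.0]]])            # zoom into the centre
grid = F.affine_grid(theta, size=(1, 1, 4, 4), align_corners=True)
patch = F.grid_sample(img, grid, mode='bilinear', align_corners=True)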
py3/lf/lf_cnn.py ADDED
@@ -0,0 +1,45 @@
import torch
from torch import nn

def convRelu(i, batchNormalization=False, leakyRelu=False):
    nc = 3
    ks = [3, 3, 3, 3, 3, 3, 2]
    ps = [1, 1, 1, 1, 1, 1, 1]
    ss = [1, 1, 1, 1, 1, 1, 1]
    nm = [64, 128, 256, 256, 512, 512, 512]

    cnn = nn.Sequential()

    nIn = nc if i == 0 else nm[i - 1]
    nOut = nm[i]
    cnn.add_module('conv{0}'.format(i),
                   nn.Conv2d(nIn, nOut, ks[i], ss[i], ps[i]))
    if batchNormalization:
        # Mehreen comment: track_running_stats is set to True to be able to load the author's state_dict.
        # It was set to False in the original py3 version (no param in the author's version).
        cnn.add_module('batchnorm{0}'.format(i), nn.InstanceNorm2d(nOut, track_running_stats=True))
        # cnn.add_module('batchnorm{0}'.format(i), nn.BatchNorm2d(nOut))
    if leakyRelu:
        cnn.add_module('relu{0}'.format(i),
                       nn.LeakyReLU(0.2, inplace=True))
    else:
        cnn.add_module('relu{0}'.format(i), nn.ReLU(True))
    return cnn

def makeCnn():

    cnn = nn.Sequential()
    cnn.add_module('convRelu{0}'.format(0), convRelu(0))
    cnn.add_module('pooling{0}'.format(0), nn.MaxPool2d(2, 2))
    cnn.add_module('convRelu{0}'.format(1), convRelu(1))
    cnn.add_module('pooling{0}'.format(1), nn.MaxPool2d(2, 2))
    cnn.add_module('convRelu{0}'.format(2), convRelu(2, True))
    cnn.add_module('convRelu{0}'.format(3), convRelu(3))
    cnn.add_module('pooling{0}'.format(2), nn.MaxPool2d(2, 2))
    cnn.add_module('convRelu{0}'.format(4), convRelu(4, True))
    cnn.add_module('convRelu{0}'.format(5), convRelu(5))
    cnn.add_module('pooling{0}'.format(3), nn.MaxPool2d(2, 2))
    cnn.add_module('convRelu{0}'.format(6), convRelu(6, True))
    cnn.add_module('pooling{0}'.format(4), nn.MaxPool2d(2, 2))

    return cnn
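A quick shape check: a 32x32 view window collapses to a single 512-dimensional descriptor, which is what position_linear in line_follower.py consumes.

import torch

cnn = makeCnn()
feat = cnn(torch.zeros(1, 3, 32, 32))
print(feat.shape)   # torch.Size([1, 512, 1, 1])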
py3/lf/line_follower.py ADDED
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from torch.autograd import Variable
from .stn.gridgen import AffineGridGen, PerspectiveGridGen, GridGen
import numpy as np
from utils import transformation_utils
from .lf_cnn import makeCnn
from .fast_patch_view import get_patches

class LineFollower(nn.Module):
    def __init__(self, output_grid_size=32, dtype=torch.cuda.FloatTensor, device="cuda"):
        super(LineFollower, self).__init__()
        cnn = makeCnn()
        position_linear = nn.Linear(512, 5)
        position_linear.weight.data.zero_()
        position_linear.bias.data[0] = 0
        position_linear.bias.data[1] = 0
        position_linear.bias.data[2] = 0

        self.output_grid_size = output_grid_size

        self.dtype = dtype
        self.cnn = cnn
        self.position_linear = position_linear
        self.device = device

    def forward(self, image, positions, steps=None, all_positions=[], reset_interval=-1, randomize=False, negate_lw=False, skip_grid=False, allow_end_early=False):

        batch_size = image.size(0)
        renorm_matrix = transformation_utils.compute_renorm_matrix(image)
        expanded_renorm_matrix = renorm_matrix.expand(batch_size, 3, 3)

        t = ((np.arange(self.output_grid_size) + 0.5) / float(self.output_grid_size))[:, None].astype(np.float32)
        t = np.repeat(t, axis=1, repeats=self.output_grid_size)
        t = Variable(torch.from_numpy(t), requires_grad=False)
        t = t.to(self.device)
        s = t.t()

        t = t[:, :, None]
        s = s[:, :, None]

        interpolations = torch.cat([
            (1-t)*s,
            (1-t)*(1-s),
            t*s,
            t*(1-s),
        ], dim=-1)

        view_window = Variable(self.dtype([
            [2, 0, 2],
            [0, 2, 0],
            [0, 0, 1]
        ])).expand(batch_size, 3, 3)

        step_bias = Variable(self.dtype([
            [1, 0, 2],
            [0, 1, 0],
            [0, 0, 1]
        ])).expand(batch_size, 3, 3)

        invert = Variable(self.dtype([
            [-1,  0, 0],
            [ 0, -1, 0],
            [ 0,  0, 1]
        ])).expand(batch_size, 3, 3)

        if negate_lw:
            view_window = invert.bmm(view_window)

        grid_gen = GridGen(32, 32, device=self.device)

        view_window_imgs = []
        next_windows = []
        reset_windows = True
        for i in range(steps):

            if i % reset_interval != 0 or reset_interval == -1:
                p_0 = positions[-1]

                if i == 0 and len(p_0.size()) == 3 and p_0.size()[1] == 3 and p_0.size()[2] == 3:
                    current_window = p_0
                    reset_windows = False
                    next_windows.append(p_0)

            else:
                p_0 = all_positions[i].type(self.dtype)
                reset_windows = True
                if randomize:
                    add_noise = p_0.clone()
                    add_noise.data.zero_()
                    mul_noise = p_0.clone()
                    mul_noise.data.fill_(1.0)

                    add_noise[:, 0].data.uniform_(-2, 2)
                    add_noise[:, 1].data.uniform_(-2, 2)
                    add_noise[:, 2].data.uniform_(-.1, .1)

                    p_0 = p_0 * mul_noise + add_noise

            if reset_windows:
                reset_windows = False

                current_window = transformation_utils.get_init_matrix(p_0)

                if len(next_windows) == 0:
                    next_windows.append(current_window)
            else:
                current_window = next_windows[-1].detach()

            crop_window = current_window.bmm(view_window)

            resampled = get_patches(image, crop_window, grid_gen, allow_end_early, device=self.device)

            if resampled is None and i > 0:
                # get_patches checks whether stopping early is allowed
                break

            if resampled is None and i == 0:
                # Odd case where it starts completely off of the edge.
                # This happens rarely, but maybe should be more elegantly
                # handled in the future.
                resampled = Variable(torch.zeros(crop_window.size(0), 3, 32, 32).type_as(image.data), requires_grad=False)

            # Process window CNN
            cnn_out = self.cnn(resampled)
            cnn_out = torch.squeeze(cnn_out, dim=2)
            cnn_out = torch.squeeze(cnn_out, dim=2)
            delta = self.position_linear(cnn_out)

            next_window = transformation_utils.get_step_matrix(delta)
            next_window = next_window.bmm(step_bias)
            if negate_lw:
                next_window = invert.bmm(next_window).bmm(invert)

            next_windows.append(current_window.bmm(next_window))

        grid_line = []
        mask_line = []
        line_done = []
        xy_positions = []

        a_pt = Variable(torch.Tensor(
            [
                [0,  1, 1],
                [0, -1, 1]
            ]
        )).to(self.device)
        a_pt = a_pt.transpose(1, 0)
        a_pt = a_pt.expand(batch_size, a_pt.size(0), a_pt.size(1))

        for i in range(0, len(next_windows)-1):

            w_0 = next_windows[i]
            w_1 = next_windows[i+1]

            pts_0 = w_0.bmm(a_pt)
            pts_1 = w_1.bmm(a_pt)
            xy_positions.append(pts_0)

            if skip_grid:
                continue

            pts = torch.cat([pts_0, pts_1], dim=2)

            grid_pts = expanded_renorm_matrix.bmm(pts)

            grid = interpolations[None, :, :, None, :] * grid_pts[:, None, None, :, :]
            grid = grid.sum(dim=-1)[..., :2]

            grid_line.append(grid)

        xy_positions.append(pts_1)

        if skip_grid:
            grid_line = None
        else:
            grid_line = torch.cat(grid_line, dim=1)

        return grid_line, view_window_imgs, next_windows, xy_positions
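The follower advances by composing 3x3 affine windows: each step multiplies the current window by the CNN-predicted step matrix (plus the fixed step_bias shift), so the crop "walks" along the text line. A toy illustration of that composition, separate from the model:

import torch

current = torch.eye(3).unsqueeze(0)        # (1, 3, 3) current view window
step = torch.tensor([[[1.0, 0.0, 2.0],     # shift two units along the line
                      [0.0, 1.0, 0.0],
                      [0.0, 0.0, 1.0]]])
nxt = current.bmm(step)                    # next view window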
py3/lf/models/__init__.py ADDED
@@ -0,0 +1,36 @@
from functools import partial
import torch

# from utils1 import coerce_to_path_and_check_exist

from .res_unet import ResUNet
from .tools import safe_model_state_dict


def get_model(name=None):
    if name is None:
        name = 'res_unet18'
    return {
        'res_unet18': partial(ResUNet, encoder_name='resnet18'),
        'res_unet34': partial(ResUNet, encoder_name='resnet34'),
        'res_unet50': partial(ResUNet, encoder_name='resnet50'),
        'res_unet101': partial(ResUNet, encoder_name='resnet101'),
        'res_unet152': partial(ResUNet, encoder_name='resnet152'),
    }[name]


def load_model_from_path(model_path, device=None, attributes_to_return=None, eval_mode=True):
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # The original wrapped model_path in coerce_to_path_and_check_exist (import commented out above);
    # torch.load accepts the path directly.
    checkpoint = torch.load(model_path, map_location=device.type)
    checkpoint['model_kwargs']['pretrained_encoder'] = False
    model = get_model(checkpoint['model_name'])(checkpoint['n_classes'], **checkpoint['model_kwargs']).to(device)
    model.load_state_dict(safe_model_state_dict(checkpoint['model_state']))
    if eval_mode:
        model.eval()
    if attributes_to_return is not None:
        if isinstance(attributes_to_return, str):
            attributes_to_return = [attributes_to_return]
        return model, [checkpoint.get(key) for key in attributes_to_return]
    else:
        return model
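A minimal sketch of building a segmentation model through this factory; the class count is illustrative:

model = get_model('res_unet34')(n_classes=3, pretrained_encoder=False)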
py3/lf/models/res_unet.py ADDED
@@ -0,0 +1,147 @@
from collections import OrderedDict

from torch import nn

from .resnet import get_resnet_model
from .tools import conv1x1, conv3x3, DecoderModule, get_norm_layer, UpsampleCatConv
# from utils1.logger import print_info, print_warning

INPUT_CHANNELS = 3
FINAL_LAYER_CHANNELS = 32
LAYER1_REDUCED_CHANNELS = 128
LAYER2_REDUCED_CHANNELS = 256
LAYER3_REDUCED_CHANNELS = 512
LAYER4_REDUCED_CHANNELS = 1024


class ResUNet(nn.Module):
    """U-Net with residual encoder backbone."""

    @property
    def name(self):
        return self.enc_name.replace('res', 'res_u')

    def __init__(self, n_classes, **kwargs):
        super().__init__()
        self.n_classes = n_classes
        self.norm_layer_kwargs = kwargs.pop('norm_layer', dict())
        self.norm_layer = get_norm_layer(**self.norm_layer_kwargs)
        self.no_maxpool = kwargs.get('no_maxpool', False)
        self.conv_as_maxpool = kwargs.get('conv_as_maxpool', True)
        self.use_upcatconv = kwargs.get('use_upcatconv', False)
        self.use_deconv = kwargs.get('use_deconv', True)
        assert not (self.use_deconv and self.use_upcatconv)
        self.same_up_channels = kwargs.get('same_up_channels', False)
        self.use_conv1x1 = kwargs.get('use_conv1x1', False)
        assert not (self.conv_as_maxpool and self.no_maxpool)
        self.enc_name = kwargs.get('encoder_name', 'resnet18')
        self.reduced_layers = kwargs.get('reduced_layers', False) and self.enc_name not in ['resnet18', 'resnet34']

        pretrained = kwargs.get('pretrained_encoder', False)
        replace_with_dilation = kwargs.get('replace_with_dilation')
        strides = kwargs.get('strides', 2)
        resnet = get_resnet_model(self.enc_name)(pretrained, progress=False, norm_layer=self.norm_layer_kwargs,
                                                 strides=strides, replace_with_dilation=replace_with_dilation)

        self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
        # XXX: maxpool creates high amplitude high freq activations, removing it leads to better results
        if self.conv_as_maxpool:
            layer0_out_channels = self.get_nb_out_channels(self.layer0)
            self.layer1 = nn.Sequential(*[conv3x3(layer0_out_channels, layer0_out_channels, stride=2),
                                          self.norm_layer(layer0_out_channels),
                                          nn.ReLU()] + list(resnet.layer1.children()))
        elif self.no_maxpool:
            self.layer1 = nn.Sequential(*list(resnet.layer1.children()))
        else:
            self.layer1 = nn.Sequential(*[resnet.maxpool] + list(resnet.layer1.children()))
        self.layer2, self.layer3, self.layer4 = resnet.layer2, resnet.layer3, resnet.layer4

        layer0_out_channels = self.get_nb_out_channels(self.layer0)
        layer1_out_channels = self.get_nb_out_channels(self.layer1)
        layer2_out_channels = self.get_nb_out_channels(self.layer2)
        layer3_out_channels = self.get_nb_out_channels(self.layer3)
        layer4_out_channels = self.get_nb_out_channels(self.layer4)
        if self.reduced_layers:
            self.layer1_red = self._reducing_layer(layer1_out_channels, LAYER1_REDUCED_CHANNELS)
            self.layer2_red = self._reducing_layer(layer2_out_channels, LAYER2_REDUCED_CHANNELS)
            self.layer3_red = self._reducing_layer(layer3_out_channels, LAYER3_REDUCED_CHANNELS)
            self.layer4_red = self._reducing_layer(layer4_out_channels, LAYER4_REDUCED_CHANNELS)
            layer1_out_channels, layer2_out_channels = LAYER1_REDUCED_CHANNELS, LAYER2_REDUCED_CHANNELS
            layer3_out_channels, layer4_out_channels = LAYER3_REDUCED_CHANNELS, LAYER4_REDUCED_CHANNELS

        self.layer4_up = self._upsampling_layer(layer4_out_channels, layer3_out_channels, layer3_out_channels)
        self.layer3_up = self._upsampling_layer(layer3_out_channels, layer2_out_channels, layer2_out_channels)
        self.layer2_up = self._upsampling_layer(layer2_out_channels, layer1_out_channels, layer1_out_channels)
        self.layer1_up = self._upsampling_layer(layer1_out_channels, layer0_out_channels, layer0_out_channels)
        self.layer0_up = self._upsampling_layer(layer0_out_channels, FINAL_LAYER_CHANNELS, INPUT_CHANNELS)
        self.final_layer = self._final_layer(FINAL_LAYER_CHANNELS)

        if not pretrained:
            self._init_conv_weights()

        print("Model {} initialised with norm_layer={}({}) and kwargs {}"
              .format(self.name, self.norm_layer.func.__name__, self.norm_layer.keywords, kwargs))

    def _reducing_layer(self, in_channels, out_channels):
        return nn.Sequential(OrderedDict([
            ('conv', conv1x1(in_channels, out_channels)),
            ('bn', self.norm_layer(out_channels)),
            ('relu', nn.ReLU()),
        ]))

    def get_nb_out_channels(self, layer):
        return list(filter(lambda e: isinstance(e, nn.Conv2d), layer.modules()))[-1].out_channels

    def _upsampling_layer(self, in_channels, out_channels, cat_channels):
        if self.use_upcatconv:
            return UpsampleCatConv(in_channels + cat_channels, out_channels, norm_layer=self.norm_layer,
                                   use_conv1x1=self.use_conv1x1)
        else:
            up_channels = in_channels if self.same_up_channels else None
            return DecoderModule(in_channels, out_channels, cat_channels, up_channels=up_channels,
                                 norm_layer=self.norm_layer, n_conv=1, use_deconv=self.use_deconv,
                                 use_conv1x1=self.use_conv1x1)

    def _final_layer(self, in_channels):
        return nn.Sequential(OrderedDict([('conv', conv1x1(in_channels, self.n_classes))]))

    def _init_conv_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_uniform_(m.weight)

    def load_state_dict_for_unet(self, state_dict):
        unloaded_params = []
        state = self.state_dict()
        for name, param in state_dict.items():
            if name in state and state[name].shape == param.shape:
                if isinstance(param, nn.Parameter):
                    param = param.data
                state[name].copy_(param)
            else:
                unloaded_params.append(name)

        if len(unloaded_params) > 0:
            print('load_state_dict: {} not found'.format(unloaded_params))

    def forward(self, x):
        x0 = self.layer0(x)
        x1 = self.layer1(x0)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        # x4 = self.layer4(x3)

        if self.reduced_layers:
            x4 = self.layer4_red(x4)
            x3 = self.layer3_red(x3)
            x2 = self.layer2_red(x2)
            x1 = self.layer1_red(x1)

        # x3 = self.layer4_up(x4, other=x3)
        x2 = self.layer3_up(x3, other=x2)
        x1 = self.layer2_up(x2, other=x1)
        x0 = self.layer1_up(x1, other=x0)
        x = self.layer0_up(x0, other=x)
        x = self.final_layer(x)

        return x
py3/lf/models/resnet.py ADDED
@@ -0,0 +1,335 @@
from toolz import keyfilter

import torch
import torch.nn as nn
from torch.utils.model_zoo import load_url as load_state_dict_from_url

from .tools import conv3x3, conv1x1, get_norm_layer

model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}


def get_resnet_model(name):
    if name is None:
        name = 'resnet18'
    return {
        'resnet18': resnet18,
        'resnet34': resnet34,
        'resnet50': resnet50,
        'resnet101': resnet101,
        'resnet152': resnet152,
        'resnext50_32x4d': resnext50_32x4d,
        'resnext101_32x8d': resnext101_32x8d,
        'wide_resnet50_2': wide_resnet50_2,
        'wide_resnet101_2': wide_resnet101_2,
    }[name]


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, n_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, strides=2, replace_with_dilation=None, **kwargs):
        super(ResNet, self).__init__()
        self.norm_layer_kwargs = kwargs.get('norm_layer', dict())
        norm_layer = get_norm_layer(**self.norm_layer_kwargs)
        self._norm_layer = norm_layer
        self.inplanes = 64
        self.groups = groups
        self.base_width = width_per_group
        self.dilation = 1
        if replace_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_with_dilation = [False, False, False]
        elif isinstance(replace_with_dilation, bool):
            replace_with_dilation = [replace_with_dilation] * 3
        assert len(replace_with_dilation) == 3
        self.strides = strides if not isinstance(strides, int) else [strides] * 5
        assert len(self.strides) == 5

        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=self.strides[0], padding=3, bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=self.strides[1], padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=self.strides[2], dilate=replace_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=self.strides[3], dilate=replace_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=self.strides[4], dilate=replace_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, n_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x


def _resnet(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
        if not model.norm_layer_kwargs.get('track_running_stats', True):
            state_dict = keyfilter(lambda k: 'running' not in k, state_dict)
        model.load_state_dict(state_dict)
    return model


def resnet18(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, **kwargs)


def resnet34(pretrained=False, progress=True, **kwargs):
    r"""ResNet-34 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet50(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnet101(pretrained=False, progress=True, **kwargs):
    r"""ResNet-101 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def resnet152(pretrained=False, progress=True, **kwargs):
    r"""ResNet-152 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-50 32x4d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)


def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], pretrained, progress, **kwargs)


def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    The model is the same as ResNet except for the bottleneck number of channels
    which is twice larger in every block. The number of channels in outer 1x1
    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
    channels, and in Wide ResNet-50-2 has 2048-1024-2048.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    kwargs['width_per_group'] = 64 * 2
    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], pretrained, progress, **kwargs)
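These constructors accept the extra strides/replace_with_dilation/norm_layer keywords added above on top of the stock torchvision interface. A hedged example:

backbone = resnet18(pretrained=False,
                    norm_layer={'name': 'instance_norm', 'affine': True},
                    strides=[2, 2, 2, 2, 1])   # keep layer4 at the layer3 resolution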
py3/lf/models/tools.py ADDED
@@ -0,0 +1,144 @@
from collections import OrderedDict
from functools import partial

from torch import nn, cat
from torch.nn import functional as F


def get_norm_layer(**kwargs):
    name = kwargs.get('name', 'instance_norm')
    momentum = kwargs.get('momentum', 0.1)
    affine = kwargs.get('affine', True)
    track_stats = kwargs.get('track_running_stats', False)
    num_groups = kwargs.get('num_groups', 32)

    norm_layer = {
        'batch_norm': partial(nn.BatchNorm2d, momentum=momentum, affine=affine, track_running_stats=track_stats),
        'group_norm': partial(nn.GroupNorm, num_groups=num_groups, affine=affine),
        'instance_norm': partial(nn.InstanceNorm2d, momentum=momentum, affine=affine, track_running_stats=track_stats),
    }[name]
    if norm_layer.func == nn.GroupNorm:
        return lambda num_channels: norm_layer(num_channels=num_channels)
    else:
        return norm_layer


def initialize_weights(*models):
    for model in models:
        for module in model.modules():
            if isinstance(module, nn.Conv2d) or isinstance(module, nn.Linear):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def safe_model_state_dict(state_dict):
    """Convert a state dict saved from a DataParallel module to normal module state_dict."""
    if not next(iter(state_dict)).startswith("module."):
        return state_dict  # abort if dict is not a DataParallel model_state
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v  # remove 'module.' prefix
    return new_state_dict


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class UpsampleCatConv(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=None, mode='bilinear', use_conv1x1=False):
        super().__init__()
        norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
        conv_layer = conv1x1 if use_conv1x1 else conv3x3
        self.mode = mode
        self.conv = conv_layer(in_channels, out_channels)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU()

    def forward(self, x, other):
        x = nn.functional.interpolate(x, size=(other.size(2), other.size(3)), mode=self.mode, align_corners=False)
        x = cat((x, other), dim=1)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


class UpsampleConv(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=None, mode='bilinear', use_conv1x1=False):
        super().__init__()
        norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
        conv_layer = conv1x1 if use_conv1x1 else conv3x3
        self.mode = mode
        self.conv = conv_layer(in_channels, out_channels)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU()

    def forward(self, x, output_size):
        x = nn.functional.interpolate(x, size=output_size[2:], mode=self.mode, align_corners=False)
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        return x


class DeconvModule(nn.Module):
    def __init__(self, in_channels, out_channels, norm_layer=None):
        super().__init__()
        norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        self.norm = norm_layer(out_channels)
        self.act = nn.ReLU()

    def forward(self, x, output_size):
        x = self.deconv(x, output_size=output_size)
        x = self.norm(x)
        x = self.act(x)
        return x


class DecoderModule(nn.Module):
    def __init__(self, in_channels, out_channels, cat_channels=None, up_channels=None,
                 norm_layer=None, n_conv=2, use_deconv=False, use_conv1x1=False):
        super().__init__()
        cat_channels = cat_channels or in_channels // 2
        up_channels = up_channels or in_channels // 2
        norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d
        self.use_deconv = use_deconv
        if use_deconv:
            self.decode = DeconvModule(in_channels, up_channels, norm_layer)
        else:
            self.decode = UpsampleConv(in_channels, up_channels, norm_layer, 'bilinear', use_conv1x1)
        self.conv_block = nn.Sequential(OrderedDict(sum([[
            ('conv{}'.format(k + 1), conv3x3(up_channels + cat_channels if k == 0 else out_channels, out_channels)),
            ('bn{}'.format(k + 1), norm_layer(out_channels)),
            ('relu{}'.format(k + 1), nn.ReLU())]
            for k in range(n_conv)], [])))

    def forward(self, x, other):
        try:
            x = self.decode(x, output_size=other.size())
        except ValueError:
            # XXX a size adjustment is needed for odd sizes
            B, C, H, W = other.size()
            h, w = H // 2 * 2, W // 2 * 2
            x = self.decode(x, output_size=(B, C, h, w))
            x = F.pad(x, (W - w, 0, H - h, 0))
        x = cat((x, other), dim=1)
        x = self.conv_block(x)
        return x
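get_norm_layer returns a one-argument factory regardless of the norm family, so the decoder modules above can stay norm-agnostic. For example:

norm = get_norm_layer(name='group_norm', num_groups=8)
layer = norm(64)   # nn.GroupNorm(num_groups=8, num_channels=64)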
py3/lf/stn/__init__.py ADDED
File without changes
py3/lf/stn/gridgen.py ADDED
@@ -0,0 +1,126 @@
import torch
from torch.autograd import Function
from torch.autograd import Variable
from torch.nn.modules.module import Module
import numpy as np


class AffineGridGenFunction(Function):
    def __init__(self, height, width):
        super(AffineGridGenFunction, self).__init__()
        self.height, self.width = height, width
        self.grid = np.zeros([self.height, self.width, 3], dtype=np.float32)
        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.height), 0), repeats=self.width, axis=0).T, 0)
        self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(np.arange(-1, 1, 2.0/self.width), 0), repeats=self.height, axis=0), 0)
        self.grid[:, :, 2] = np.ones([self.height, width])
        self.grid = torch.from_numpy(self.grid.astype(np.float32))
        # print(self.grid)

    def forward(self, input1):
        self.input1 = input1
        output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        for i in range(input1.size(0)):
            self.batchgrid[i] = self.grid

        if input1.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
            output = output.cuda()

        for i in range(input1.size(0)):
            output = torch.bmm(self.batchgrid.view(-1, self.height*self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 2)

        return output

    def backward(self, grad_output):

        grad_input1 = torch.zeros(self.input1.size())

        if grad_output.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
            grad_input1 = grad_input1.cuda()
        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 2), 1, 2), self.batchgrid.view(-1, self.height*self.width, 3))

        return grad_input1

class AffineGridGen(Module):
    def __init__(self, height, width):
        super(AffineGridGen, self).__init__()
        self.height, self.width = height, width
        self.f = AffineGridGenFunction(self.height, self.width)

    def forward(self, input):
        return self.f(input)


class PerspectiveGridGenFunction(Function):
    def __init__(self, height, width):
        super(PerspectiveGridGenFunction, self).__init__()
        self.height, self.width = height, width
        self.grid = np.zeros([self.height, self.width, 3], dtype=np.float32)
        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.height), 0), repeats=self.width, axis=0).T, 0)
        self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.width), 0), repeats=self.height, axis=0), 0)
        self.grid[:, :, 2] = np.ones([self.height, width])
        self.grid = torch.from_numpy(self.grid.astype(np.float32))

    def forward(self, input1):
        self.input1 = input1
        output = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        self.batchgrid = torch.zeros(torch.Size([input1.size(0)]) + self.grid.size())
        for i in range(input1.size(0)):
            self.batchgrid[i] = self.grid

        if input1.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
            output = output.cuda()

        for i in range(input1.size(0)):
            output = torch.bmm(self.batchgrid.view(-1, self.height*self.width, 3), torch.transpose(input1, 1, 2)).view(-1, self.height, self.width, 3)

        return output

    def backward(self, grad_output):

        grad_input1 = torch.zeros(self.input1.size())

        if grad_output.is_cuda:
            self.batchgrid = self.batchgrid.cuda()
            grad_input1 = grad_input1.cuda()
        grad_input1 = torch.baddbmm(grad_input1, torch.transpose(grad_output.view(-1, self.height*self.width, 3), 1, 2), self.batchgrid.view(-1, self.height*self.width, 3))

        return grad_input1

class PerspectiveGridGen(Module):
    def __init__(self, height, width):
        super(PerspectiveGridGen, self).__init__()
        self.height, self.width = height, width
        self.f = PerspectiveGridGenFunction(self.height, self.width)

    def forward(self, input):
        return self.f(input)

class GridGen(Module):
    def __init__(self, height, width, device="cuda"):
        super(GridGen, self).__init__()
        self.device = device
        self.height, self.width = height, width
        self.grid = np.zeros([self.height, self.width, 3], dtype=np.float32)

        grid_space_h = (np.arange(self.height) + 0.5) / float(self.height)
        grid_space_w = (np.arange(self.width) + 0.5) / float(self.width)

        grid_space_h = 2 * grid_space_h - 1
        grid_space_w = 2 * grid_space_w - 1

        self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(grid_space_h, 0), repeats=self.width, axis=0).T, 0)
        self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(grid_space_w, 0), repeats=self.height, axis=0), 0)
        # self.grid[:, :, 0] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.height), 0), repeats=self.width, axis=0).T, 0)
        # self.grid[:, :, 1] = np.expand_dims(np.repeat(np.expand_dims(np.linspace(-1, 1, self.width), 0), repeats=self.height, axis=0), 0)
        self.grid[:, :, 2] = np.ones([self.height, width])
        self.grid = Variable(torch.from_numpy(self.grid.astype(np.float32)), requires_grad=False).to(self.device)

    def forward(self, input):
        out = torch.matmul(input[:, None, None, :, :], self.grid[None, :, :, :, None])
        out = out.squeeze(-1)
        return out
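GridGen maps a batch of 3x3 transforms to a dense grid of homogeneous sample points; an identity transform simply returns the base grid:

import torch

gg = GridGen(32, 32, device="cpu")
grid = gg(torch.eye(3).unsqueeze(0))   # (1, 32, 32, 3) homogeneous coordinates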
py3/sol/__init__.py ADDED
File without changes
py3/sol/crop_transform.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from sol import crop_utils
import numpy as np

class CropTransform(object):
    def __init__(self, crop_params):
        crop_size = crop_params['crop_size']
        self.random_crop_params = crop_params
        self.pad_params = ((crop_size, crop_size), (crop_size, crop_size), (0, 0))

    def __call__(self, sample):
        org_img = sample['img']
        gt = sample['sol_gt']

        org_img = np.pad(org_img, self.pad_params, 'mean')

        # Shift both ground-truth points into the padded coordinate frame.
        gt[:, :, 0] = gt[:, :, 0] + self.pad_params[0][0]
        gt[:, :, 1] = gt[:, :, 1] + self.pad_params[1][0]

        gt[:, :, 2] = gt[:, :, 2] + self.pad_params[0][0]
        gt[:, :, 3] = gt[:, :, 3] + self.pad_params[1][0]

        crop_params, org_img, gt_match = crop_utils.generate_random_crop(org_img, gt, self.random_crop_params)

        # Keep only the ground truth that fell inside the crop and move it
        # into crop-local coordinates.
        gt = gt[gt_match][None, ...]
        gt[..., 0] = gt[..., 0] - crop_params['dim1'][0]
        gt[..., 1] = gt[..., 1] - crop_params['dim0'][0]

        gt[..., 2] = gt[..., 2] - crop_params['dim1'][0]
        gt[..., 3] = gt[..., 3] - crop_params['dim0'][0]

        return {
            "img": org_img,
            "sol_gt": gt
        }
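A self-contained sketch of driving CropTransform (the array shapes are assumptions inferred from how the fields are indexed above, and the `sol` package is assumed to be on the path):

import numpy as np

sample = {
    "img": np.random.randint(0, 255, (100, 100, 3)).astype(np.float32),
    # Two start-of-line annotations, each (x0, y0, x1, y1): shape (2, 1, 4).
    "sol_gt": np.array([[[10.0, 20.0, 30.0, 20.0]],
                        [[15.0, 60.0, 40.0, 60.0]]]),
}
transform = CropTransform({"crop_size": 32, "prob_label": 0.5})
out = transform(sample)
print(out["img"].shape, out["sol_gt"].shape)   # (32, 32, 3) and (1, K, 4)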
py3/sol/crop_utils.py ADDED
@@ -0,0 +1,48 @@
import cv2
import numpy as np
import sys

def perform_crop(img, crop):
    cs = crop['crop_size']
    cropped_gt_img = img[crop['dim0'][0]:crop['dim0'][1], crop['dim1'][0]:crop['dim1'][1]]
    scaled_gt_img = cv2.resize(cropped_gt_img, (cs, cs), interpolation=cv2.INTER_CUBIC)
    return scaled_gt_img


def generate_random_crop(img, gt, params):

    contains_label = np.random.random() < params['prob_label']
    cs = params['crop_size']

    cnt = 0
    while True:

        dim0 = np.random.randint(0, img.shape[0] - cs)
        dim1 = np.random.randint(0, img.shape[1] - cs)

        crop = {
            "dim0": [dim0, dim0 + cs],
            "dim1": [dim1, dim1 + cs],
            "crop_size": cs
        }

        # TODO: this only works for the center points
        gt_match = np.zeros_like(gt[..., 0:2])
        gt_match[..., 0][gt[..., 0] < dim1] = 1
        gt_match[..., 0][gt[..., 0] > dim1 + cs] = 1

        gt_match[..., 1][gt[..., 1] < dim0] = 1
        gt_match[..., 1][gt[..., 1] > dim0 + cs] = 1

        gt_match = 1 - gt_match
        gt_match = np.logical_and(gt_match[..., 0], gt_match[..., 1])

        # Accept a crop that contains a label when one was requested, or give
        # up and accept whatever we have after 100 failed attempts.
        if (gt_match.sum() > 0 and contains_label) or cnt > 100:
            cropped_gt_img = perform_crop(img, crop)
            return crop, cropped_gt_img, np.where(gt_match != 0)

        # Accept an empty crop when a label-free crop was requested.
        if gt_match.sum() == 0 and not contains_label:
            cropped_gt_img = perform_crop(img, crop)
            return crop, cropped_gt_img, np.where(gt_match != 0)

        cnt += 1
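A direct call might look like this (shapes assumed as in the transform above; with prob_label = 1.0 the loop keeps sampling windows until one covers the annotated point, or falls back after 100 tries):

import numpy as np

img = np.random.rand(200, 200, 3).astype(np.float32)
gt = np.array([[[50.0, 70.0, 80.0, 70.0]]])      # one annotated point
params = {"crop_size": 64, "prob_label": 1.0}    # always ask for a labeled crop
crop, patch, kept = generate_random_crop(img, gt, params)
print(crop["dim0"], crop["dim1"], patch.shape, kept)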
py3/sol/start_of_line_finder.py ADDED
@@ -0,0 +1,42 @@
import torch
from torch.autograd import Variable
from torch import nn
from . import vgg

class StartOfLineFinder(nn.Module):
    def __init__(self, base_0, base_1):
        super(StartOfLineFinder, self).__init__()

        self.cnn = vgg.vgg11()
        self.base_0 = base_0
        self.base_1 = base_1

    def forward(self, img):
        y = self.cnn(img)

        # Grid priors: the center coordinate of every cell of the CNN output
        # map, scaled back to input-image pixels by the network stride
        # (base_0 vertically, base_1 horizontally).
        priors_0 = Variable(torch.arange(0, y.size(2)).type_as(img.data), requires_grad=False)[None, :, None]
        priors_0 = (priors_0 + 0.5) * self.base_0
        priors_0 = priors_0.expand(y.size(0), priors_0.size(1), y.size(3))
        priors_0 = priors_0[:, None, :, :]

        priors_1 = Variable(torch.arange(0, y.size(3)).type_as(img.data), requires_grad=False)[None, None, :]
        priors_1 = (priors_1 + 0.5) * self.base_1
        priors_1 = priors_1.expand(y.size(0), y.size(2), priors_1.size(2))
        priors_1 = priors_1[:, None, :, :]

        # Channel 0 is a sigmoid confidence; channels 1-2 are position offsets
        # added to the grid priors; channels 3-4 are the remaining regression
        # outputs, left unnormalized.
        predictions = torch.cat([
            torch.sigmoid(y[:, 0:1, :, :]),
            y[:, 1:2, :, :] + priors_0,
            y[:, 2:3, :, :] + priors_1,
            y[:, 3:4, :, :],
            y[:, 4:5, :, :]
        ], dim=1)

        # Flatten the map into one (confidence, ...) row per grid cell.
        predictions = predictions.transpose(1, 3).contiguous()
        predictions = predictions.view(predictions.size(0), -1, 5)
        return predictions
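An end-to-end sketch: with base_0 = base_1 = 16 (matching vgg11's four pooling stages) the finder emits one 5-vector per 16x16 input patch, and candidate starts can be taken by thresholding the confidence channel. The 0.1 threshold below is illustrative, not a value from this repo:

import torch

sol = StartOfLineFinder(base_0=16, base_1=16)
img = torch.rand(1, 3, 256, 256)
preds = sol(img)                    # (1, 256, 5): one row per grid cell
conf = preds[0, :, 0]
candidates = preds[0, conf > 0.1]   # keep confident start-of-line proposals
print(preds.shape, candidates.shape)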
py3/sol/vgg.py ADDED
@@ -0,0 +1,157 @@
import torch
from torch import nn
import torch.utils.model_zoo as model_zoo
import math


__all__ = [
    'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
    'vgg19_bn', 'vgg19',
]


class VGG(nn.Module):

    def __init__(self, features, num_classes=1000):
        super(VGG, self).__init__()
        self.features = features

    def forward(self, x):
        x = self.features(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def make_layers(cfg, batch_norm=False):
    layers = []
    in_channels = 3
    for i, v in enumerate(cfg):
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if i == len(cfg) - 1:
                # The final conv is appended without BN/ReLU so the
                # OUTPUT_FEATURES channels stay unbounded regression outputs.
                layers += [conv2d]
                break
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)

# Every configuration below replaces the usual classifier head with a final
# 5-channel convolution, so the network emits a dense prediction map.
OUTPUT_FEATURES = 5
cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, OUTPUT_FEATURES],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, OUTPUT_FEATURES],
    'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, OUTPUT_FEATURES],
}

# NOTE: model_urls is not defined in this file, and the modified cfgs above
# would not match the ImageNet checkpoints anyway, so the pretrained=True
# branches below are effectively unsupported as written.


def vgg11(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['A']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    return model


def vgg11_bn(pretrained=False, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
    return model


def vgg13(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['B']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
    return model


def vgg13_bn(pretrained=False, **kwargs):
    """VGG 13-layer model (configuration "B") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
    return model


def vgg16(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['D']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
    return model


def vgg16_bn(pretrained=False, **kwargs):
    """VGG 16-layer model (configuration "D") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
    return model


def vgg19(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E")

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E']), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
    return model


def vgg19_bn(pretrained=False, **kwargs):
    """VGG 19-layer model (configuration "E") with batch normalization

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
    return model
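Because the classifier head is replaced by the final 5-channel convolution, these constructors return fully convolutional networks. A quick shape check (illustrative):

import torch

net = vgg11()
out = net(torch.rand(1, 3, 128, 96))
print(out.shape)   # torch.Size([1, 5, 8, 6]) -- H and W are divided by 16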
py3/utils/__init__.py ADDED
File without changes
py3/utils/character_set.ipynb ADDED
@@ -0,0 +1,539 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a0f742b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/home/msaeed3/mehreen/source/start_follow_read/py3\n",
      "....c ب\n",
      "....c أ\n",
      "....c ن\n",
      "....c \n",
      "....c ا\n",
      "....c ل\n",
      "....c ت\n",
      "....c ع\n",
      "....c و\n",
      "....c ي\n",
      "....c ٢\n",
      "....c س\n",
      "....c ط\n",
      "....c ج\n",
      "....c ه\n",
      "....c ز\n",
      "....c ف\n",
      "....c ذ\n",
      "....c ٳ\n",
      "....c ك\n",
      "....c ٪\n",
      "....c ٜ\n",
      "....c ١\n",
      "....c د\n",
      "....c ة\n",
      "....c م\n",
      "....c ٬\n",
      "....c ق\n",
      "....c ر\n",
      "....c ش\n",
      "....c ٚ\n",
      "....c ح\n",
      "....c -\n",
      "....c !\n",
      "....c ص\n",
      "....c ض\n",
      "....c ٝ\n",
      "....c ث\n",
      "....c 7\n",
      "....c ٙ\n",
      "....c ٨\n",
      "....c ک\n",
      "....c ٮ\n",
      "....c ڤ\n",
      "....c ٤\n",
      "....c ں\n",
      "....c 4\n",
      "....c ٰ\n",
      "....c ]\n",
      "....c ڡ\n",
      "....c ى\n",
      "....c \\\n",
      "....c 2\n",
      "....c ٛ\n",
      "....c =\n",
      "....c إ\n",
      "....c غ\n",
      "....c ٲ\n",
      "....c ّ\n",
      "....c >\n",
      "....c .\n",
      "....c )\n",
      "....c <\n",
      "....c ئ\n",
      "....c |\n",
      "....c 0\n",
      "....c +\n",
      "....c x\n",
      "....c ؟\n",
      "....c خ\n",
      "....c }\n",
      "....c &\n",
      "....c %\n",
      "....c ،\n",
      "....c @\n",
      "....c $\n",
      "....c ء\n",
      "....c 5\n",
      "....c 8\n",
      "....c _\n",
      "....c ٌ\n",
      "....c ×\n",
      "....c ^\n",
      "....c ٍ\n",
      "....c `\n",
      "....c [\n",
      "....c آ\n",
      "....c َ\n",
      "....c ;\n",
      "....c ً\n",
      "....c ُ\n",
      "....c /\n",
      "....c ٕ\n",
      "....c ~\n",
      "....c \"\n",
      "....c ٖ\n",
      "....c ظ\n",
      "....c 3\n",
      "....c :\n",
      "....c ۟\n",
      "....c ٥\n",
      "....c چ\n",
      "....c ٣\n",
      "....c ,\n",
      "....c ٧\n",
      "....c ﮐ\n",
      "....c {\n",
      "....c 9\n",
      "....c ?\n",
      "....c '\n",
      "....c ْ\n",
      "....c *\n",
      "....c ـ\n",
      "....c ٔ\n",
      "....c #\n",
      "....c ٓ\n",
      "....c ِ\n",
      "....c 1\n",
      "....c 6\n",
      "....c ‘\n",
      "....c (\n",
      "....c ٠\n",
      "....c ٞ\n",
      "....c ٯ\n",
      "....c ؤ\n",
      "....c ٘\n",
      "....c ٟ\n",
      "....c ٴ\n",
      "....c ݘ\n",
      "....c ٫\n",
      "....c ی\n",
      "....c ٦\n",
      "....c ٩\n",
      "....c ٵ\n",
      "....c ٱ\n",
      "....c –\n",
      "....c ؛\n",
      "....c ٶ\n",
      "....c ٭\n",
      "....c ٗ\n",
      "....c ﭐ\n",
      "....c �\n",
      "....c ﺟ\n",
      "....c ﮞ\n",
      "ﮞ 866\n",
      "� 876\n",
      "ﭐ 888\n",
      "ﮐ 892\n",
      "ﺟ 910\n",
      "۟ 2654\n",
      "‘ 2685\n",
      "ݘ 2718\n",
      "ی 2790\n",
      "– 2802\n",
      "٥ 7794\n",
      "ٴ 7866\n",
      "ں 7885\n",
      "٤ 7898\n",
      "ڡ 7902\n",
      "ٝ 7921\n",
      "ٵ 7933\n",
      "ٲ 7944\n",
      "ٶ 7962\n",
      "٬ 7966\n",
      "ٳ 7971\n",
      "٧ 8011\n",
      "٭ 8012\n",
      "٢ 8019\n",
      "٘ 8045\n",
      "٦ 8049\n",
      "ٰ 8056\n",
      "٠ 8061\n",
      "ٟ 8067\n",
      "ٙ 8080\n",
      "٩ 8081\n",
      "ٯ 8083\n",
      "٪ 8084\n",
      "٫ 8098\n",
      "ک 8102\n",
      "ٱ 8118\n",
      "ٜ 8123\n",
      "ڤ 8126\n",
      "ٮ 8133\n",
      "ٞ 8133\n",
      "١ 8158\n",
      "ٛ 8162\n",
      "٨ 8163\n",
      "ٚ 8194\n",
      "چ 8219\n",
      "ٗ 8222\n",
      "٣ 8378\n",
      "ٍ 12142\n",
      "~ 12159\n",
      "9 12171\n",
      "ِ 12173\n",
      "1 12198\n",
      "ٓ 12228\n",
      "[ 12238\n",
      "{ 12281\n",
      "' 12285\n",
      "! 12317\n",
      "× 12331\n",
      "< 12337\n",
      "2 12344\n",
      "ْ 12345\n",
      "_ 12349\n",
      "- 12350\n",
      "% 12360\n",
      "8 12366\n",
      "5 12373\n",
      "3 12381\n",
      "ٔ 12383\n",
      "} 12384\n",
      "# 12386\n",
      "x 12392\n",
      "ً 12392\n",
      ": 12394\n",
      "7 12398\n",
      "* 12402\n",
      "= 12404\n",
      "+ 12431\n",
      "> 12454\n",
      "6 12457\n",
      "ّ 12459\n",
      "\\ 12461\n",
      ") 12462\n",
      "؛ 12467\n",
      "` 12477\n",
      "$ 12479\n",
      "0 12486\n",
      "؟ 12487\n",
      "? 12487\n",
      "ـ 12506\n",
      ". 12514\n",
      "( 12517\n",
      "ٌ 12534\n",
      "^ 12539\n",
      "\" 12541\n",
      "/ 12544\n",
      "، 12565\n",
      "ٖ 12566\n",
      "َ 12568\n",
      "ٕ 12574\n",
      "; 12588\n",
      "ُ 12604\n",
      "& 12610\n",
      "@ 12610\n",
      "] 12644\n",
      "4 12668\n",
      ", 12673\n",
      "| 12690\n",
      "آ 13719\n",
      "ؤ 17881\n",
      "ظ 25831\n",
      "غ 33164\n",
      "ء 41411\n",
      "ئ 58235\n",
      "إ 62171\n",
      "ث 63826\n",
      "ذ 73869\n",
      "ز 77302\n",
      "ض 80181\n",
      "ص 106844\n",
      "ى 107009\n",
      "خ 113804\n",
      "ط 116166\n",
      "ش 120758\n",
      "أ 155544\n",
      "ج 183706\n",
      "ك 213684\n",
      "ح 228902\n",
      "ه 259392\n",
      "ق 272453\n",
      "ف 306028\n",
      "س 308145\n",
      "د 387589\n",
      "ع 403225\n",
      "ب 418369\n",
      "ة 431012\n",
      "ت 556063\n",
      "ر 567321\n",
      "ن 612221\n",
      "و 639292\n",
      "م 800866\n",
      "ي 919298\n",
      "ل 1496107\n",
      "ا 2022160\n",
      "  3392229\n",
      "('Size:', 144)\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import json\n",
    "import os\n",
    "from collections import defaultdict\n",
    "\n",
    "# These are for RASAM full pages\n",
    "#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/RASAM/sfr/'\n",
    "#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/RASAM/sfr/'\n",
    "#OUT_NAME_c = 'char_set_rasam.json'\n",
    "\n",
    "# These are for regions in RASAM and RASM\n",
    "#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/RASM/regions_sfr/'\n",
    "#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/RASM/regions_sfr/'\n",
    "\n",
    "# These are for RASM, RASAM, MoiseK, KHATT\n",
    "#OUT_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
    "#DATA_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
    "#OUT_NAME_c = 'char_set_arabic.json'\n",
    "\n",
    "OUT_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
    "OUT_NAME_c = 'char_set_line_images.json'\n",
    "DATA_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
    "\n",
    "os.chdir('/home/msaeed3/mehreen/source/start_follow_read/py3/')\n",
    "print(os.getcwd())\n",
    "\n",
    "def load_char_set(char_set_path):\n",
    "    with open(char_set_path) as f:\n",
    "        char_set = json.load(f)\n",
    "\n",
    "    idx_to_char = {}\n",
    "    for k,v in char_set['idx_to_char'].items():\n",
    "        idx_to_char[int(k)] = v\n",
    "\n",
    "    return idx_to_char, char_set['char_to_idx']\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    character_set_path = OUT_PATH_c + OUT_NAME_c\n",
    "    out_char_to_idx = {}\n",
    "    out_idx_to_char = {}\n",
    "    char_freq = defaultdict(int)\n",
    "\n",
    "    dirs = [DATA_PATH_c]\n",
    "\n",
    "    input_data_files = ['Train.json', 'Valid.json', 'Test.json']\n",
    "    data_file = []\n",
    "    for path in dirs:\n",
    "        for i in range(len(input_data_files)):\n",
    "            data_file = path + input_data_files[i]\n",
    "            with open(data_file) as f:\n",
    "                paths = json.load(f)\n",
    "\n",
    "            for json_path, image_path in paths:\n",
    "\n",
    "                with open(json_path) as f:\n",
    "                    data = json.load(f)\n",
    "\n",
    "                cnt = 1 # this is important that this starts at 1 not 0\n",
    "                for data_item in data:\n",
    "                    # Mehreen: cater for NaN gt\n",
    "                    if 'gt' in data_item and type(data_item['gt']) == float:\n",
    "                        continue\n",
    "                    for c in data_item.get('gt', ''):\n",
    "\n",
    "                        if c is None:\n",
    "                            print(\"There was a None GT\")\n",
    "                            continue\n",
    "                        if c not in out_char_to_idx:\n",
    "                            print('....c',c)\n",
    "                            out_char_to_idx[c] = cnt\n",
    "                            out_idx_to_char[cnt] = c\n",
    "                            cnt += 1\n",
    "                        char_freq[c] += 1\n",
    "\n",
    "\n",
    "    out_char_to_idx2 = {}\n",
    "    out_idx_to_char2 = {}\n",
    "\n",
    "    for i, c in enumerate(sorted(out_char_to_idx.keys())):\n",
    "        out_char_to_idx2[c] = i+1\n",
    "        out_idx_to_char2[i+1] = c\n",
    "\n",
    "    output_data = {\n",
    "        \"char_to_idx\": out_char_to_idx2,\n",
    "        \"idx_to_char\": out_idx_to_char2\n",
    "    }\n",
    "\n",
    "    for k,v in sorted(iter(char_freq.items()), key=lambda x: x[1]):\n",
    "        print(k, v)\n",
    "\n",
    "    print((\"Size:\", len(output_data['char_to_idx'])))\n",
    "\n",
    "    with open(character_set_path, 'w') as outfile:\n",
    "        json.dump(output_data, outfile)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6337a429",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "DATA_PATH_c = '/home/msaeed3/mehreen/datasets/arabic_all/'\n",
    "OUT_NAME_c = 'char_set_arabic.json'\n",
    "char_set_path = os.path.join(DATA_PATH_c, OUT_NAME_c)\n",
    "\n",
    "with open(char_set_path) as f:\n",
    "    char_set = json.load(f)\n",
    "\n",
    "larger_set = list(char_set['char_to_idx'].keys())\n",
    "larger_set.sort()\n",
    "for l in larger_set:\n",
    "    print(l, ascii(l))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c54a65d",
   "metadata": {},
   "outputs": [],
   "source": [
    "DATA_PATH_c = '/home/msaeed3/mehreen/datasets/synthetic/line_images/'\n",
    "OUT_NAME_c = 'char_set_line_images.json'\n",
    "char_set_path = os.path.join(DATA_PATH_c, OUT_NAME_c)\n",
    "\n",
    "with open(char_set_path) as f:\n",
    "    char_set = json.load(f)\n",
    "\n",
    "smaller_set = list(char_set['char_to_idx'].keys())\n",
    "smaller_set.sort()\n",
    "for s in smaller_set:\n",
    "    print(s, ascii(s))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "085f5ac4",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34fc501d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# lists have no .difference, so convert to a set first\n",
    "diff = set(larger_set).difference(smaller_set)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce121a50",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_special_ascii_set():\n",
    "    char_set = []\n",
    "    ascii_range = [[33, 64], [91, 96], [123, 126],\n",
    "                   # Arabic-Indic digits\n",
    "                   [ord('\\u0661'), ord('\\u0669')]]\n",
    "    for r in ascii_range:\n",
    "        for val in range(r[0], r[1]+1):\n",
    "            char_set.append(chr(val))\n",
    "    return char_set\n",
    "\n",
    "print((ord('\\u0661')))\n",
    "special_ascii = get_special_ascii_set()\n",
    "special_ascii.sort()\n",
    "print(special_ascii)\n",
    "special_ascii = set(special_ascii)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a89913c",
   "metadata": {},
   "outputs": [],
   "source": [
    "larger_set = set(larger_set)\n",
    "smaller_set = set(smaller_set)\n",
    "l_diff_ss = larger_set.difference(special_ascii.union(smaller_set))\n",
    "l_diff_ss = (list(l_diff_ss))\n",
    "l_diff_ss.sort()\n",
    "for l in l_diff_ss:\n",
    "    print(l, ascii(l))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72f97354",
   "metadata": {},
   "outputs": [],
   "source": [
    "larger_set.difference(smaller_set)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch",
   "language": "python",
   "name": "torch"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
py3/utils/character_set.py ADDED
@@ -0,0 +1,61 @@
import sys
import json
import os
from collections import defaultdict

def load_char_set(char_set_path):
    with open(char_set_path) as f:
        char_set = json.load(f)

    idx_to_char = {}
    for k, v in char_set['idx_to_char'].items():
        idx_to_char[int(k)] = v

    return idx_to_char, char_set['char_to_idx']

if __name__ == "__main__":
    # Usage: all arguments except the last are dataset list files; the last
    # argument is the output path for the character-set JSON.
    character_set_path = sys.argv[-1]
    out_char_to_idx = {}
    out_idx_to_char = {}
    char_freq = defaultdict(int)
    for i in range(1, len(sys.argv) - 1):
        data_file = sys.argv[i]
        with open(data_file) as f:
            paths = json.load(f)

        for json_path, image_path in paths:
            with open(json_path) as f:
                data = json.load(f)

            cnt = 1  # indices must start at 1, not 0 (index 0 is typically reserved for the CTC blank)
            for data_item in data:
                # default to an empty string so items without a 'gt' field
                # are skipped instead of raising a TypeError
                for c in data_item.get('gt', ''):
                    if c is None:
                        print("There was a None GT")
                        continue
                    if c not in out_char_to_idx:
                        out_char_to_idx[c] = cnt
                        out_idx_to_char[cnt] = c
                        cnt += 1
                    char_freq[c] += 1


    out_char_to_idx2 = {}
    out_idx_to_char2 = {}

    for i, c in enumerate(sorted(out_char_to_idx.keys())):
        out_char_to_idx2[c] = i + 1
        out_idx_to_char2[i + 1] = c

    output_data = {
        "char_to_idx": out_char_to_idx2,
        "idx_to_char": out_idx_to_char2
    }

    for k, v in sorted(iter(char_freq.items()), key=lambda x: x[1]):
        print(k, v)

    print(("Size:", len(output_data['char_to_idx'])))

    with open(character_set_path, 'w') as outfile:
        json.dump(output_data, outfile)
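Everything before the last command-line argument is a dataset list and the last argument is the output path, so a typical invocation and round-trip check might look like this (file names are placeholders):

#   python utils/character_set.py Train.json Valid.json Test.json charset.json

idx_to_char, char_to_idx = load_char_set('charset.json')
for c, i in char_to_idx.items():
    assert idx_to_char[i] == c   # the two maps are mutual inverses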
py3/utils/continuous_state.py ADDED
@@ -0,0 +1,87 @@

import torch
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch import nn

import sol
from sol.start_of_line_finder import StartOfLineFinder
from lf.line_follower import LineFollower
from hw import cnn_lstm

from utils import safe_load

import numpy as np
import cv2
import json
import sys
import os
import time
import random

def init_model(config, sol_dir='best_validation', lf_dir='best_validation', hw_dir='best_validation',
               only_load=None, device="cuda"):

    dtype = torch.FloatTensor
    if 'cuda' in device:
        dtype = torch.cuda.FloatTensor

    base_0 = config['network']['sol']['base0']
    base_1 = config['network']['sol']['base1']

    sol = None
    lf = None
    hw = None

    if only_load is None or only_load == 'sol' or 'sol' in only_load:
        sol = StartOfLineFinder(base_0, base_1)
        sol_state = safe_load.torch_state(os.path.join(config['snapshot_path'], "sol.pt"))
        sol.load_state_dict(sol_state)
        sol.to(device)

    if only_load is None or only_load == 'lf' or 'lf' in only_load:
        # This field may not be present in the config and may be added by the
        # calling module, so you won't necessarily see it in the config file.
        pt_file = 'lf.pt'

        lf = LineFollower(config['network']['hw']['input_height'], dtype=dtype, device=device)
        lf_state = safe_load.torch_state(os.path.join(config['snapshot_path'], pt_file))

        # Special case for backward compatibility with the previous way the
        # LF weights were saved (a nested 'cnn' dict plus module objects).
        if 'cnn' in lf_state:
            new_state = {}
            for k, v in lf_state.items():
                print(k)
                if k == 'cnn':
                    for k2, v2 in v.items():
                        if "running" in k2:
                            pass  # skip BatchNorm running statistics
                        else:
                            new_state[k + "." + k2] = v2
                if k == 'position_linear':
                    for k2, v2 in v.state_dict().items():
                        new_state[k + "." + k2] = v2

            lf_state = new_state

        lf.load_state_dict(lf_state)
        lf.to(device)

    if only_load is None or only_load == 'hw' or 'hw' in only_load:
        hw = cnn_lstm.create_model(config['network']['hw'])
        hw_state = safe_load.torch_state(os.path.join(config['snapshot_path'], "hw.pt"))
        hw.load_state_dict(hw_state)
        hw.to(device)

    return sol, lf, hw
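A sketch of loading just one of the three networks. The config path is one shipped in this commit; the snapshot_path key is assumed to be present or injected by the caller (as the comment above notes), and the checkpoint is assumed to be loadable on the chosen device:

import yaml

with open('model/trial_26_A/set0/config_2600.yaml') as f:
    config = yaml.load(f, Loader=yaml.Loader)
config.setdefault('snapshot_path', 'model/trial_26_A/set0/pretrain')

sol, lf, hw = init_model(config, only_load='sol', device='cpu')
print(sol is not None, lf is None, hw is None)   # True True True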
py3/utils/dataset_parse.py ADDED
@@ -0,0 +1,17 @@
import json
import os

def load_file_list(config):
    file_list_path = config['file_list']
    with open(file_list_path) as f:
        data = json.load(f)

    # Resolve each (json, image) pair relative to the configured folders.
    for d in data:
        json_path = os.path.join(config['json_folder'], d[0])
        img_path = os.path.join(config['img_folder'], d[1])

        d[0] = json_path
        d[1] = img_path

    return data
py3/utils/dataset_wrapper.py ADDED
@@ -0,0 +1,27 @@
class DatasetWrapper(object):
    """Iterate over a dataset for a fixed number of items per pass,
    restarting the underlying iterator (and counting epochs) as needed."""

    def __init__(self, dataset, count):
        self.count = count
        self.idx = 0
        self.dataset = dataset
        self.iter_dataset = iter(dataset)
        self.epoch = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.idx >= self.count:
            self.idx = 0
            raise StopIteration

        self.idx += 1
        while True:
            try:
                return next(self.iter_dataset)
            except StopIteration:
                self.iter_dataset = iter(self.dataset)
                self.epoch += 1
                try:
                    return next(self.iter_dataset)
                except StopIteration:
                    raise Exception("Appears as if dataset is empty")
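The wrapper decouples "iterations per pass" from the dataset length; a minimal sketch:

wrapped = DatasetWrapper([1, 2, 3], count=5)
print(list(wrapped))    # [1, 2, 3, 1, 2] -- the source restarts mid-pass
print(wrapped.epoch)    # 1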
py3/utils/error_rates.py ADDED
@@ -0,0 +1,21 @@
import editdistance

def cer(r, h):
    # Remove any double or trailing whitespace before comparing characters
    r = ' '.join(r.split())
    h = ' '.join(h.split())

    return err(r, h)

def err(r, h):
    dis = editdistance.eval(r, h)
    if len(r) == 0:
        return len(h)

    return float(dis) / float(len(r))

def wer(r, h):
    r = r.split()
    h = h.split()

    return err(r, h)
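A worked example: one character edit out of eleven reference characters, and one word out of two.

print(cer("hello world", "helo world"))   # 1/11 = 0.0909...
print(wer("hello world", "helo world"))   # 1/2  = 0.5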
py3/utils/fast_inverse.py ADDED
@@ -0,0 +1,58 @@
import numpy as np
from torch.autograd import Variable
import torch

def adjoint(A):
    """Compute the adjugate (inverse without division by det); ...x3x3 input,
    i.e. an array of 3x3 matrices, is assumed."""
    AI = np.empty_like(A)
    for i in range(3):
        AI[..., i, :] = np.cross(A[..., i - 2, :], A[..., i - 1, :])
    return AI

def inverse_transpose(A):
    """Efficiently compute the inverse-transpose for a stack of 3x3 matrices."""
    I = adjoint(A)
    det = dot(I, A).mean(axis=-1)
    return I / det[..., None, None]

def inverse(A):
    """Inverse of a stack of 3x3 matrices."""
    return np.swapaxes(inverse_transpose(A), -1, -2)

def dot(A, B):
    """Dot arrays of vecs; contract over last indices."""
    return np.einsum('...i,...i->...', A, B)

def adjoint_torch(A):
    AI = A.clone()
    for i in range(3):
        AI[..., i, :] = torch.cross(A[..., i - 2, :], A[..., i - 1, :])
    return AI

def inverse_transpose_torch(A):
    I = adjoint_torch(A)
    det = dot_torch(I, A).mean(dim=-1)
    return I / det[:, None, None]

def inverse_torch(A):
    return inverse_transpose_torch(A).transpose(1, 2)

def dot_torch(A, B):
    A_view = A.view(-1, 1, 3)
    B_view = B.contiguous().view(-1, 3, 1)
    out = torch.bmm(A_view, B_view)
    out_view = out.view(A.size()[:-1])
    return out_view


if __name__ == "__main__":
    # Sanity check: the NumPy and torch implementations should agree.
    A = np.random.rand(2, 3, 3)
    I = inverse(A)

    A_torch = Variable(torch.from_numpy(A))

    I_torch = inverse_torch(A_torch)
    print(I)
    print(I_torch)
@@ -0,0 +1,30 @@
 
import torch
import time
import json

def torch_state(path):
    # Retry with a growing back-off in case the snapshot is being written
    # by another process at the same time.
    for i in range(10):
        try:
            state = torch.load(path)
            return state
        except Exception:
            print("Failed to load", i, path)
            time.sleep(i)

    print("Failed to load state")
    return None

def json_state(path):
    for i in range(10):
        try:
            with open(path) as f:
                state = json.load(f)
            return state
        except Exception:
            print("Failed to load", i, path)
            time.sleep(i)

    print("Failed to load state")
    return None
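A usage sketch against one of the snapshots shipped in this commit; the retry-with-back-off exists so a reader does not crash if the trainer is mid-write:

state = torch_state('model/trial_26_A/set0/pretrain/sol.pt')
if state is not None:
    print(sorted(state.keys())[:5])   # first few parameter names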