Mehreen Saeed commited on
Commit
7fc513e
·
1 Parent(s): c309edf

added different device options

Browse files
Files changed (2) hide show
  1. arabic/page_htr.py +8 -7
  2. py3/utils/safe_load.py +2 -1
arabic/page_htr.py CHANGED
@@ -228,7 +228,7 @@ def page_htr_one_file(img_file, config_file, model_mode="pretrain", device="cuda
228
  return page_json
229
 
230
 
231
- def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None):
232
 
233
  HW, idx_to_char = get_hw(config_file)
234
 
@@ -243,9 +243,9 @@ def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key
243
  img = cv2.imread(img_file)
244
  line_img = warp.get_line_image(values['coord'], img)
245
  line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
246
- img=line_img, read_image=False)
247
- line_text = make_manual_text_correction(line_text)
248
- # New change
249
  line_text_logical_order = clean.get_clean_visual_order(line_text)
250
  json_obj[line]['text'] = line_text_logical_order
251
 
@@ -269,13 +269,14 @@ if __name__ == "__main__":
269
 
270
  args = parser.parse_args()
271
  json_obj = {}
 
272
  if args.line_htr == 1:
273
  json_obj = json.loads(args.original_json)
274
  json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
275
- model_mode="pretrain", line_key=args.line_key)
276
  else:
277
- json_obj = json.loads(args.original_json)
278
- json_obj = page_htr_one_file(args.img_path, args.config_file, device="cpu")
279
 
280
  print('BEGIN_OUT')
281
  print(json.dumps(json_obj))
 
228
  return page_json
229
 
230
 
231
+ def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None, device="cuda"):
232
 
233
  HW, idx_to_char = get_hw(config_file)
234
 
 
243
  img = cv2.imread(img_file)
244
  line_img = warp.get_line_image(values['coord'], img)
245
  line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
246
+ img=line_img, read_image=False, device=device)
247
+
248
+
249
  line_text_logical_order = clean.get_clean_visual_order(line_text)
250
  json_obj[line]['text'] = line_text_logical_order
251
 
 
269
 
270
  args = parser.parse_args()
271
  json_obj = {}
272
+ device = "cuda" if torch.cuda.is_available() else "cpu"
273
  if args.line_htr == 1:
274
  json_obj = json.loads(args.original_json)
275
  json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
276
+ model_mode="pretrain", line_key=args.line_key, device=device)
277
  else:
278
+ json_obj = json.loads(args.original_json)
279
+ json_obj = page_htr_one_file(args.img_path, args.config_file, device=device)
280
 
281
  print('BEGIN_OUT')
282
  print(json.dumps(json_obj))
py3/utils/safe_load.py CHANGED
@@ -11,7 +11,8 @@ def torch_state(path):
11
  print('Warning: Model not found')
12
  else:
13
  print('Good: Model is found')
14
- state = torch.load(path, weights_only=True, map_location="cpu")
 
15
  return state
16
  except:
17
  print("Failed to load",i,path)
 
11
  print('Warning: Model not found')
12
  else:
13
  print('Good: Model is found')
14
+ device = "cuda" if torch.cuda.is_available() else "cpu"
15
+ state = torch.load(path, weights_only=True, map_location=device)
16
  return state
17
  except:
18
  print("Failed to load",i,path)