Spaces:
Build error
Build error
Mehreen Saeed commited on
Commit ·
7fc513e
1
Parent(s): c309edf
added different device options
Browse files- arabic/page_htr.py +8 -7
- py3/utils/safe_load.py +2 -1
arabic/page_htr.py
CHANGED
|
@@ -228,7 +228,7 @@ def page_htr_one_file(img_file, config_file, model_mode="pretrain", device="cuda
|
|
| 228 |
return page_json
|
| 229 |
|
| 230 |
|
| 231 |
-
def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None):
|
| 232 |
|
| 233 |
HW, idx_to_char = get_hw(config_file)
|
| 234 |
|
|
@@ -243,9 +243,9 @@ def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key
|
|
| 243 |
img = cv2.imread(img_file)
|
| 244 |
line_img = warp.get_line_image(values['coord'], img)
|
| 245 |
line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
|
| 246 |
-
img=line_img, read_image=False)
|
| 247 |
-
|
| 248 |
-
|
| 249 |
line_text_logical_order = clean.get_clean_visual_order(line_text)
|
| 250 |
json_obj[line]['text'] = line_text_logical_order
|
| 251 |
|
|
@@ -269,13 +269,14 @@ if __name__ == "__main__":
|
|
| 269 |
|
| 270 |
args = parser.parse_args()
|
| 271 |
json_obj = {}
|
|
|
|
| 272 |
if args.line_htr == 1:
|
| 273 |
json_obj = json.loads(args.original_json)
|
| 274 |
json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
|
| 275 |
-
model_mode="pretrain", line_key=args.line_key)
|
| 276 |
else:
|
| 277 |
-
json_obj = json.loads(args.original_json)
|
| 278 |
-
json_obj = page_htr_one_file(args.img_path, args.config_file, device=
|
| 279 |
|
| 280 |
print('BEGIN_OUT')
|
| 281 |
print(json.dumps(json_obj))
|
|
|
|
| 228 |
return page_json
|
| 229 |
|
| 230 |
|
| 231 |
+
def hw_one_file(img_file, config_file, json_obj, model_mode="pretrain", line_key=None, device="cuda"):
|
| 232 |
|
| 233 |
HW, idx_to_char = get_hw(config_file)
|
| 234 |
|
|
|
|
| 243 |
img = cv2.imread(img_file)
|
| 244 |
line_img = warp.get_line_image(values['coord'], img)
|
| 245 |
line_text = test_hw.get_predicted_str(HW, None, idx_to_char, flip=True,
|
| 246 |
+
img=line_img, read_image=False, device=device)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
line_text_logical_order = clean.get_clean_visual_order(line_text)
|
| 250 |
json_obj[line]['text'] = line_text_logical_order
|
| 251 |
|
|
|
|
| 269 |
|
| 270 |
args = parser.parse_args()
|
| 271 |
json_obj = {}
|
| 272 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 273 |
if args.line_htr == 1:
|
| 274 |
json_obj = json.loads(args.original_json)
|
| 275 |
json_obj = hw_one_file(args.img_path, args.config_file, json_obj,
|
| 276 |
+
model_mode="pretrain", line_key=args.line_key, device=device)
|
| 277 |
else:
|
| 278 |
+
json_obj = json.loads(args.original_json)
|
| 279 |
+
json_obj = page_htr_one_file(args.img_path, args.config_file, device=device)
|
| 280 |
|
| 281 |
print('BEGIN_OUT')
|
| 282 |
print(json.dumps(json_obj))
|
py3/utils/safe_load.py
CHANGED
|
@@ -11,7 +11,8 @@ def torch_state(path):
|
|
| 11 |
print('Warning: Model not found')
|
| 12 |
else:
|
| 13 |
print('Good: Model is found')
|
| 14 |
-
|
|
|
|
| 15 |
return state
|
| 16 |
except:
|
| 17 |
print("Failed to load",i,path)
|
|
|
|
| 11 |
print('Warning: Model not found')
|
| 12 |
else:
|
| 13 |
print('Good: Model is found')
|
| 14 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 15 |
+
state = torch.load(path, weights_only=True, map_location=device)
|
| 16 |
return state
|
| 17 |
except:
|
| 18 |
print("Failed to load",i,path)
|