henryu commited on
Commit
cdf0320
·
1 Parent(s): 2a9ba6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -184
app.py CHANGED
@@ -1,184 +1,43 @@
1
- """
2
- A main training script.
3
- """
4
-
5
-
6
- # Copyright (c) Facebook, Inc. and its affiliates.
7
- import warnings
8
- warnings.filterwarnings('ignore') # never print matching warnings
9
- import logging
10
- import os
11
- from collections import OrderedDict
12
- import torch
13
- import uniperceiver.utils.comm as comm
14
- from uniperceiver.config import get_cfg, CfgNode
15
- from uniperceiver.engine import DefaultTrainer, default_argument_parser, default_setup, launch, build_engine, add_moe_arguments
16
-
17
- #!TODO re-implement hooks
18
- from uniperceiver.engine import hooks
19
- from uniperceiver.modeling import add_config
20
- from uniperceiver.utils.env import init_distributed_mode, check_dist_portfile
21
- try:
22
- import deepspeed
23
- DEEPSPEED_INSTALLED = True
24
- except:
25
- DEEPSPEED_INSTALLED = False
26
-
27
- import copy
28
-
29
- def add_data_prefix(cfg):
30
- # TODO: more flexible method
31
- data_dir = os.getenv("DATA_PATH", None)
32
- mapping_list = [
33
- [cfg.DATALOADER, 'FEATS_FOLDER', ['DATALOADER',]],
34
- [cfg.DATALOADER, 'ANNO_FOLDER', ['DATALOADER', ]],
35
- [cfg.DATALOADER, 'CLASS_NAME_FILE', ['DATALOADER', ]],
36
- [cfg.INFERENCE, 'VOCAB', ['INFERENCE', ]],
37
- [cfg.INFERENCE, 'VAL_ANNFILE', ['INFERENCE', ]],
38
- [cfg.INFERENCE, 'TEST_ANNFILE', ['INFERENCE',]],
39
- [cfg.MODEL, 'WEIGHTS', ['MODEL',]],
40
- ]
41
- whitelist = ["BERT", "CLIP", "CLIP_CAPTION"]
42
- if data_dir:
43
- for node, attr ,_ in mapping_list:
44
- if node[attr] != '' and not node[attr].startswith('.') and not node[attr].startswith('/') and not node[attr].startswith('work_dirs') and not node[attr].startswith('cluster') and not node[attr].startswith('s3://') and node[attr] not in whitelist:
45
- setattr(node, attr, os.path.join(data_dir, node[attr]))
46
- for task in cfg.TASKS:
47
- for _, item, key_list in mapping_list:
48
- config_tmp = task
49
- for key in key_list:
50
- if key in config_tmp:
51
- config_tmp = config_tmp[key]
52
- if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith('/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith('cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
53
- config_tmp[item] = os.path.join(data_dir, config_tmp[item])
54
-
55
- mapping_list = [
56
- ['', 'FILE_PATH', ['SHARED_TARGETS_CFG',]],
57
- ]
58
- if cfg.SHARED_TARGETS is None:
59
- cfg.SHARED_TARGETS = []
60
- for share_targets in cfg.SHARED_TARGETS:
61
- for _, item, key_list in mapping_list:
62
- config_tmp = share_targets
63
- for key in key_list:
64
- config_tmp = config_tmp[key]
65
- if item in config_tmp and config_tmp[item] != '' and not config_tmp[item].startswith('.') and not config_tmp[item].startswith(
66
- '/') and not config_tmp[item].startswith('work_dirs') and not config_tmp[item].startswith(
67
- 'cluster') and not config_tmp[item].startswith('s3://') and config_tmp[item] not in whitelist:
68
- config_tmp[item] = os.path.join(data_dir, config_tmp[item])
69
-
70
-
71
-
72
- def add_default_setting_for_multitask_config(cfg):
73
- # merge some default config in (CfgNode) uniperceiver/config/defaults.py to each task config (dict)
74
-
75
- tasks_config_temp = cfg.TASKS
76
- num_tasks = len(tasks_config_temp)
77
- cfg.pop('TASKS', None)
78
-
79
- cfg.TASKS = [copy.deepcopy(cfg) for _ in range(num_tasks)]
80
-
81
- for i, task_config in enumerate(tasks_config_temp):
82
- cfg.TASKS[i].merge_from_other_cfg(CfgNode(task_config))
83
- cfg.TASKS[i] = cfg.TASKS[i].to_dict_object()
84
- pass
85
-
86
-
87
- def setup(args):
88
- """
89
- Create configs and perform basic setups.
90
- """
91
- cfg = get_cfg()
92
- tmp_cfg = cfg.load_from_file_tmp(args.config_file)
93
- add_config(cfg, tmp_cfg)
94
-
95
- cfg.merge_from_file(args.config_file)
96
- add_data_prefix(cfg)
97
-
98
- cfg.merge_from_list(args.opts)
99
- #
100
- add_default_setting_for_multitask_config(cfg)
101
- cfg.freeze()
102
- default_setup(cfg, args)
103
- return cfg
104
-
105
- def main(args):
106
- cfg = setup(args)
107
-
108
- """
109
- If you'd like to do anything fancier than the standard training logic,
110
- consider writing your own training loop (see plain_train_net.py) or
111
- subclassing the trainer.
112
- """
113
- trainer = build_engine(cfg)
114
- trainer.resume_or_load(resume=args.resume)
115
- trainer.cast_layers()
116
-
117
- if args.eval_only:
118
- print('---------------------------')
119
- print('eval model only')
120
- print('---------------------------\n')
121
- res = None
122
- if trainer.val_data_loader is not None:
123
-
124
- if trainer.model_ema is not None and args.eval_ema:
125
- if comm.is_main_process():
126
- print('using ema model for evaluation')
127
- res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
128
- else:
129
- if args.eval_ema and comm.is_main_process():
130
- print('no ema model exists! using master model for evaluation')
131
- res = trainer.test(trainer.cfg, trainer.model, trainer.val_data_loader, trainer.val_evaluator, epoch=-1)
132
-
133
- if comm.is_main_process():
134
- print(res)
135
-
136
- if trainer.test_data_loader is not None:
137
- if trainer.model_ema is not None and args.eval_ema:
138
- if comm.is_main_process():
139
- print('using ema model for evaluation')
140
- res = trainer.test(trainer.cfg, trainer.model_ema.ema, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
141
- else:
142
- if args.eval_ema and comm.is_main_process():
143
- print('no ema model exists! using master model for evaluation')
144
- res = trainer.test(trainer.cfg, trainer.model, trainer.test_data_loader, trainer.test_evaluator, epoch=-1)
145
- if comm.is_main_process():
146
- print(res)
147
- return res
148
-
149
- return trainer.train()
150
-
151
- def get_args_parser():
152
- parser = default_argument_parser()
153
- if DEEPSPEED_INSTALLED:
154
- parser = deepspeed.add_config_arguments(parser)
155
- parser = add_moe_arguments(parser)
156
-
157
- parser.add_argument('--init_method', default='slurm', type=str)
158
- parser.add_argument('--local_rank', default=0, type=int)
159
- parser.add_argument("--eval-ema", action="store_true", help="perform evaluation using ema")
160
- args = parser.parse_args()
161
-
162
- return args
163
-
164
- if __name__ == "__main__":
165
- args = get_args_parser()
166
- print("Command Line Args:", args)
167
- if args.init_method == 'slurm':
168
- # slurm init
169
- check_dist_portfile()
170
- init_distributed_mode(args)
171
- main(args)
172
- elif args.init_method == 'pytorch':
173
- main(args)
174
- else:
175
- # follow 'd2' use default `mp.spawn` to init dist training
176
- print('using \'mp.spawn\' for dist init! ')
177
- launch(
178
- main,
179
- args.num_gpus,
180
- num_machines=args.num_machines,
181
- machine_rank=args.machine_rank,
182
- dist_url=args.dist_url,
183
- args=(args,),
184
- )
 
1
+ from codecs import encode, decode
2
+ import requests
3
+ import gradio as gr
4
+
5
+
6
+ def infer(im):
7
+ im.save("converted.png")
8
+ url = "https://ajax.thehive.ai/api/demo/classify?endpoint=text_recognition"
9
+ files = {
10
+ "image": ("converted.png", open("converted.png", "rb"), "image/png"),
11
+ "model_type": (None, "detection"),
12
+ "media_type": (None, "photo"),
13
+ }
14
+ headers = {"referer": "https://thehive.ai/"}
15
+
16
+ res = requests.post(url, headers=headers, files=files)
17
+
18
+ text = ""
19
+ blocks = []
20
+ for output in res.json()["response"]["output"]:
21
+ text += output["block_text"]
22
+ for poly in output["bounding_poly"]:
23
+ blocks.append(
24
+ {
25
+ "text": "".join([c["class"] for c in poly["classes"]]),
26
+ "rect": poly["dimensions"],
27
+ }
28
+ )
29
+
30
+ text = decode(encode(text, "latin-1", "backslashreplace"), "unicode-escape")
31
+
32
+ return text, blocks
33
+
34
+
35
+ iface = gr.Interface(
36
+ fn=infer,
37
+ title="Hive OCR simple",
38
+ description="Demo for Hive OCR. Transcribe and analyze media depicting typed, written, or graphic text",
39
+ inputs=[gr.Image(type="pil")],
40
+ outputs=["text", "json"],
41
+ examples=["20131216170659.jpg"],
42
+ article='<a href="https://thehive.ai/hive-ocr-solutions">Hive OCR</a>',
43
+ ).launch()