sayed555 commited on
Commit
de2b10e
·
verified ·
1 Parent(s): 82f073c

Upload 48 files

Browse files
Files changed (48) hide show
  1. utils/__init__.py +0 -0
  2. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  3. utils/__pycache__/__init__.cpython-39.pyc +0 -0
  4. utils/__pycache__/config.cpython-310.pyc +0 -0
  5. utils/__pycache__/download_util.cpython-310.pyc +0 -0
  6. utils/__pycache__/exceptions.cpython-310.pyc +0 -0
  7. utils/__pycache__/fontformat.cpython-310.pyc +0 -0
  8. utils/__pycache__/imgproc_utils.cpython-310.pyc +0 -0
  9. utils/__pycache__/io_utils.cpython-310.pyc +0 -0
  10. utils/__pycache__/logger.cpython-310.pyc +0 -0
  11. utils/__pycache__/message.cpython-310.pyc +0 -0
  12. utils/__pycache__/proj_imgtrans.cpython-310.pyc +0 -0
  13. utils/__pycache__/registry.cpython-310.pyc +0 -0
  14. utils/__pycache__/shared.cpython-310.pyc +0 -0
  15. utils/__pycache__/shared.cpython-39.pyc +0 -0
  16. utils/__pycache__/split_text_region.cpython-310.pyc +0 -0
  17. utils/__pycache__/stroke_width_calculator.cpython-310.pyc +0 -0
  18. utils/__pycache__/structures.cpython-310.pyc +0 -0
  19. utils/__pycache__/text_layout.cpython-310.pyc +0 -0
  20. utils/__pycache__/text_processing.cpython-310.pyc +0 -0
  21. utils/__pycache__/textblock.cpython-310.pyc +0 -0
  22. utils/__pycache__/textblock_mask.cpython-310.pyc +0 -0
  23. utils/__pycache__/textlines_merge.cpython-310.pyc +0 -0
  24. utils/__pycache__/watermark_utils.cpython-310.pyc +0 -0
  25. utils/__pycache__/zluda_config.cpython-310.pyc +0 -0
  26. utils/appinfo.py +2 -0
  27. utils/config.py +287 -0
  28. utils/download_util.py +371 -0
  29. utils/exceptions.py +20 -0
  30. utils/fontformat.py +136 -0
  31. utils/imgproc_utils.py +413 -0
  32. utils/io_utils.py +243 -0
  33. utils/logger.py +99 -0
  34. utils/message.py +67 -0
  35. utils/package.py +289 -0
  36. utils/proj_imgtrans.py +720 -0
  37. utils/registry.py +272 -0
  38. utils/shared.py +160 -0
  39. utils/split_text_region.py +386 -0
  40. utils/stroke_width_calculator.py +113 -0
  41. utils/structures.py +84 -0
  42. utils/text_layout.py +477 -0
  43. utils/text_processing.py +237 -0
  44. utils/textblock.py +908 -0
  45. utils/textblock_mask.py +394 -0
  46. utils/textlines_merge.py +568 -0
  47. utils/watermark_utils.py +68 -0
  48. utils/zluda_config.py +32 -0
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (154 Bytes). View file
 
utils/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (133 Bytes). View file
 
utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (11.2 kB). View file
 
utils/__pycache__/download_util.cpython-310.pyc ADDED
Binary file (9.14 kB). View file
 
utils/__pycache__/exceptions.cpython-310.pyc ADDED
Binary file (1.12 kB). View file
 
utils/__pycache__/fontformat.cpython-310.pyc ADDED
Binary file (5.73 kB). View file
 
utils/__pycache__/imgproc_utils.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
utils/__pycache__/io_utils.cpython-310.pyc ADDED
Binary file (7.77 kB). View file
 
utils/__pycache__/logger.cpython-310.pyc ADDED
Binary file (2.96 kB). View file
 
utils/__pycache__/message.cpython-310.pyc ADDED
Binary file (2.23 kB). View file
 
utils/__pycache__/proj_imgtrans.cpython-310.pyc ADDED
Binary file (22.1 kB). View file
 
utils/__pycache__/registry.cpython-310.pyc ADDED
Binary file (8.16 kB). View file
 
utils/__pycache__/shared.cpython-310.pyc ADDED
Binary file (4.14 kB). View file
 
utils/__pycache__/shared.cpython-39.pyc ADDED
Binary file (4.08 kB). View file
 
utils/__pycache__/split_text_region.cpython-310.pyc ADDED
Binary file (9.84 kB). View file
 
utils/__pycache__/stroke_width_calculator.cpython-310.pyc ADDED
Binary file (3.49 kB). View file
 
utils/__pycache__/structures.cpython-310.pyc ADDED
Binary file (2.84 kB). View file
 
utils/__pycache__/text_layout.cpython-310.pyc ADDED
Binary file (9.26 kB). View file
 
utils/__pycache__/text_processing.cpython-310.pyc ADDED
Binary file (5.42 kB). View file
 
utils/__pycache__/textblock.cpython-310.pyc ADDED
Binary file (26.8 kB). View file
 
utils/__pycache__/textblock_mask.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
utils/__pycache__/textlines_merge.cpython-310.pyc ADDED
Binary file (19 kB). View file
 
utils/__pycache__/watermark_utils.cpython-310.pyc ADDED
Binary file (1.65 kB). View file
 
utils/__pycache__/zluda_config.cpython-310.pyc ADDED
Binary file (1.13 kB). View file
 
utils/appinfo.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
# Branch label of this build (how it is consumed is outside this file).
branch = "dev"
# Application version string.
version = "1.4.0"
utils/config.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, traceback
2
+ import os.path as osp
3
+ import copy
4
+
5
+ from . import shared
6
+ from .fontformat import FontFormat
7
+ from .structures import List, Dict, Config, field, nested_dataclass
8
+ from .logger import logger as LOGGER
9
+ from .io_utils import json_dump_nested_obj, np, serialize_np
10
+
11
+
12
@nested_dataclass
class ModuleConfig(Config):
    """Module selection, stage switches and per-module parameters.

    Holds the name of the selected text detector / OCR / inpainter /
    translator, one ``*_params`` dict per module kind, and the
    ``enable_*`` switches queried by :meth:`stage_enabled`.
    """

    textdetector: str = 'ctd'
    ocr: str = "mit48px"
    inpainter: str = 'lama_large_512px'
    translator: str = "google"
    enable_detect: bool = True
    keep_exist_textlines: bool = False
    enable_ocr: bool = True
    enable_translate: bool = True
    enable_inpaint: bool = True
    textdetector_params: Dict = field(default_factory=lambda: dict())
    ocr_params: Dict = field(default_factory=lambda: dict())
    translator_params: Dict = field(default_factory=lambda: dict())
    inpainter_params: Dict = field(default_factory=lambda: dict())
    translate_source: str = '日本語'
    translate_target: str = '简体中文'
    check_need_inpaint: bool = True
    load_model_on_demand: bool = False
    empty_runcache: bool = False

    def get_params(self, module_key: str, for_saving=False) -> dict:
        """Return the ``<module_key>_params`` dict.

        Args:
            module_key: Prefix of the params attribute, e.g. ``'ocr'``.
            for_saving: When true, return a reduced copy instead of the live
                dict: entries whose params are ``None`` are dropped, the
                ``'description'`` key is dropped, and dict-valued parameters
                are replaced by their ``'value'`` entry.

        Returns:
            The live params dict, or the reduced copy when ``for_saving``.
        """
        params = self[module_key + '_params']
        if not for_saving:
            return params

        saving = {}
        # Distinct loop name: the original reused ``module_key`` here and
        # shadowed the parameter.
        for module_name, module_params in params.items():
            if module_params is None:
                continue
            saving[module_name] = {
                pk: (pv['value'] if isinstance(pv, dict) else pv)
                for pk, pv in module_params.items()
                if pk != 'description'
            }
        return saving

    def get_saving_params(self, to_dict=True):
        """Return a shallow copy whose four ``*_params`` are in saving form.

        Args:
            to_dict: Return the copy's ``__dict__`` instead of the copy.
        """
        params = copy.copy(self)
        for module_key in ('ocr', 'inpainter', 'textdetector', 'translator'):
            setattr(params, module_key + '_params',
                    self.get_params(module_key, for_saving=True))
        if to_dict:
            return params.__dict__
        return params

    def stage_enabled(self, idx: int):
        """Return the enable flag of stage ``idx``.

        0 = detect, 1 = OCR, 2 = translate, 3 = inpaint.

        Raises:
            Exception: If ``idx`` is not one of 0..3.
        """
        if idx == 0:
            return self.enable_detect
        elif idx == 1:
            return self.enable_ocr
        elif idx == 2:
            return self.enable_translate
        elif idx == 3:
            return self.enable_inpaint
        else:
            raise Exception(f'not supported stage idx: {idx}')

    def all_stages_disabled(self):
        """Return True when none of the four stages is enabled."""
        # ``not (...)`` instead of ``(...) is False``: the identity test is
        # only right for real bools and reports "not all disabled" for a
        # falsy non-bool flag such as 0.
        return not (self.enable_detect or self.enable_ocr
                    or self.enable_translate or self.enable_inpaint)
75
+
76
+
77
@nested_dataclass
class DrawPanelConfig(Config):
    """Persisted drawing-panel settings (pen, inpaint brush, rect tool)."""

    # Pen tool: three-component colour, width, shape index.
    pentool_color: List = field(default_factory=lambda: [0, 0, 0])
    pentool_width: float = 30.0
    pentool_shape: int = 0
    # Inpaint brush: width and shape index.
    inpainter_width: float = 30.0
    inpainter_shape: int = 0
    # Index of the currently selected tool.
    current_tool: int = 0
    # Rect tool options. The ``rectool`` / ``recttool`` spellings are
    # attribute names and must stay exactly as they are.
    rectool_auto: bool = False
    rectool_method: int = 0
    recttool_dilate_ksize: int = 0
88
+
89
@nested_dataclass
class ProgramConfig(Config):
    """Top-level program configuration; :meth:`load` builds one from JSON."""

    module: ModuleConfig = field(default_factory=lambda: ModuleConfig())
    drawpanel: DrawPanelConfig = field(default_factory=lambda: DrawPanelConfig())
    global_fontformat: FontFormat = field(default_factory=lambda: FontFormat())
    recent_proj_list: List = field(default_factory=lambda: list())
    show_page_list: bool = False
    imgtrans_paintmode: bool = False
    imgtrans_textedit: bool = True
    imgtrans_textblock: bool = True
    auto_watermark: bool = False
    mask_transparency: float = 0.
    original_transparency: float = 0.
    open_recent_on_startup: bool = True
    let_fntsize_flag: int = 0
    let_fntstroke_flag: int = 0
    let_fntcolor_flag: int = 0
    let_fnt_scolor_flag: int = 0
    let_fnteffect_flag: int = 1
    let_alignment_flag: int = 0
    let_writing_mode_flag: int = 0
    let_family_flag: int = 0
    let_autolayout_flag: bool = True
    let_uppercase_flag: bool = True
    let_show_only_custom_fonts_flag: bool = False
    let_textstyle_indep_flag: bool = False
    text_styles_path: str = osp.join(shared.DEFAULT_TEXTSTYLE_DIR, 'default.json')
    fsearch_case: bool = False
    fsearch_whole_word: bool = False
    fsearch_regex: bool = False
    fsearch_range: int = 0
    gsearch_case: bool = False
    gsearch_whole_word: bool = False
    gsearch_regex: bool = False
    gsearch_range: int = 0
    darkmode: bool = False
    textselect_mini_menu: bool = True
    fold_textarea: bool = False
    show_source_text: bool = True
    show_trans_text: bool = True
    saladict_shortcut: str = "Alt+S"
    search_url: str = "https://www.google.com/search?q="
    ocr_sublist: List = field(default_factory=lambda: list())
    restore_ocr_empty: bool = False
    pre_mt_sublist: List = field(default_factory=lambda: list())
    mt_sublist: List = field(default_factory=lambda: list())
    display_lang: str = field(default_factory=lambda: shared.DEFAULT_DISPLAY_LANG)  # to always apply shared.DEFAULT_DISPLAY_LANG
    imgsave_quality: int = 100
    imgsave_ext: str = '.png'
    intermediate_imgsave_ext: str = '.png'
    show_text_style_preset: bool = True
    expand_tstyle_panel: bool = True
    show_text_effect_panel: bool = True
    expand_teffect_panel: bool = True
    text_advanced_format_panel: bool = True
    expand_tadvanced_panel: bool = True
    # Watermark settings
    watermark_enabled: bool = False
    watermark_path: str = ''
    # Annotated float: the default 0.7 is not an int.
    watermark_opacity: float = 0.7

    @staticmethod
    def load(cfg_path: str):
        """Build a ``ProgramConfig`` from the JSON file at ``cfg_path``.

        Older layouts are migrated in place before construction: a legacy
        top-level ``'dl'`` section becomes ``'module'`` (with its
        ``*_setup_params`` keys renamed to ``*_params``), and old translator
        names are mapped to their current spelling.

        Raises:
            OSError, json.JSONDecodeError: If the file cannot be read/parsed.
            Exception: Whatever ``ProgramConfig(**config_dict)`` raises.
        """
        with open(cfg_path, 'r', encoding='utf8') as f:
            config_dict = json.loads(f.read())

        # for backward compatibility
        if 'dl' in config_dict:
            dl = config_dict.pop('dl')
            if 'module' not in config_dict:
                for module_name in ('textdetector', 'inpainter', 'ocr', 'translator'):
                    legacy_key = module_name + '_setup_params'
                    if legacy_key in dl:
                        dl[module_name + '_params'] = dl.pop(legacy_key)
                config_dict['module'] = dl

        if 'module' in config_dict:
            module_cfg = config_dict['module']
            repl_pairs = {'baidu': 'Baidu', 'caiyun': 'Caiyun', 'chatgpt': 'ChatGPT', 'Deepl': 'DeepL', 'papago': 'Papago'}
            # ``.get`` rather than indexing: a module section without these
            # keys is still loadable (the dataclass defaults apply); a
            # KeyError here would make the caller drop the whole config.
            trans_params = module_cfg.get('translator_params')
            if isinstance(trans_params, dict):
                for old_name, new_name in repl_pairs.items():
                    if old_name in trans_params:
                        trans_params[new_name] = trans_params.pop(old_name)
            translator = module_cfg.get('translator')
            if translator in repl_pairs:
                module_cfg['translator'] = repl_pairs[translator]

        return ProgramConfig(**config_dict)
185
+
186
+
187
# Active program configuration; None until load_config() assigns it.
pcfg: ProgramConfig = None
# Loaded text styles; load_textstyle_from() refills this list in place.
text_styles: List[FontFormat] = []
# Not assigned anywhere in this module after this point.
active_format: FontFormat = None
190
+
191
def load_textstyle_from(p: str, raise_exception = False):
    """Replace the module-level ``text_styles`` with the styles stored at ``p``.

    The file must contain a JSON list of keyword dicts for ``FontFormat``.
    Entries that cannot be turned into a ``FontFormat`` are skipped with a
    warning. On success ``pcfg.text_styles_path`` is set to ``p`` (when a
    config is loaded).

    Args:
        p: Path of the text-style JSON file.
        raise_exception: Re-raise read/parse errors after logging them
            instead of returning silently.
    """
    global text_styles, pcfg

    if not osp.exists(p):
        LOGGER.warning(f'Text style {p} does not exist.')
        return

    try:
        with open(p, 'r', encoding='utf8') as f:
            style_list = json.loads(f.read())
        styles_loaded = []
        for style in style_list:
            try:
                styles_loaded.append(FontFormat(**style))
            except Exception:
                LOGGER.warning(f'Skip invalid text style: {style}')
    except Exception as e:
        LOGGER.error(f'Failed to load text style from {p}: {e}')
        if raise_exception:
            raise  # bare raise keeps the original traceback
        return

    # Refill in place so other references to the list see the new styles.
    text_styles[:] = styles_loaded
    # pcfg is None until load_config() has run; don't fail after the styles
    # were already replaced.
    if pcfg is not None:
        pcfg.text_styles_path = p
217
+
218
def load_config(config_path: str = shared.CONFIG_PATH):
    """Load the program config into the module-level ``pcfg``, then its text styles.

    Falls back to a default ``ProgramConfig`` when the file is missing or
    cannot be loaded. If the configured text-style file does not exist, the
    default style file is used instead (and created empty if it is missing
    too).

    Args:
        config_path: Config file to use; when it differs from
            ``shared.CONFIG_PATH`` that global is updated to it.
    """
    global pcfg

    if config_path != shared.CONFIG_PATH:
        shared.CONFIG_PATH = config_path
        LOGGER.info(f'Using specified config file at {shared.CONFIG_PATH}')

    if osp.exists(shared.CONFIG_PATH):
        try:
            config = ProgramConfig.load(shared.CONFIG_PATH)
        except Exception as e:
            LOGGER.exception(e)
            LOGGER.warning("Failed to load config file, using default config")
            config = ProgramConfig()
    else:
        LOGGER.info(f'{shared.CONFIG_PATH} does not exist, new config file will be created.')
        config = ProgramConfig()

    pcfg = config

    p = pcfg.text_styles_path
    if not osp.exists(p):
        dp = osp.join(shared.DEFAULT_TEXTSTYLE_DIR, 'default.json')
        if p != dp and osp.exists(dp):
            # Log before reassigning p, otherwise the message names the
            # default path twice instead of the missing one.
            LOGGER.warning(f'Text style {p} does not exist, use the default from {dp}.')
        else:
            dp_dir = osp.dirname(dp)
            if dp_dir:
                os.makedirs(dp_dir, exist_ok=True)
            with open(dp, 'w', encoding='utf8') as f:
                f.write(json.dumps([], ensure_ascii=False))
            LOGGER.info(f'New text style file created at {dp}.')
        # Both branches end up on the default file; previously the second
        # branch created it but still tried to load the missing path.
        p = dp
    load_textstyle_from(p)
248
+
249
+
250
def json_dump_program_config(obj, **kwargs):
    """Serialize ``obj`` to a JSON string (non-ASCII characters kept as-is).

    Values ``json`` cannot encode natively are converted as follows: numpy
    arrays/scalars via ``serialize_np``, ``ModuleConfig`` via its
    ``get_saving_params()``, anything else via its ``__dict__``.
    Extra keyword arguments are forwarded to ``json.dumps``.
    """
    def _to_serializable(value):
        if isinstance(value, (np.ndarray, np.ScalarType)):
            return serialize_np(value)
        if isinstance(value, ModuleConfig):
            return value.get_saving_params()
        return value.__dict__

    return json.dumps(obj, default=_to_serializable, ensure_ascii=False, **kwargs)
258
+
259
+
260
def save_config():
    """Write the module-level ``pcfg`` to ``shared.CONFIG_PATH`` as JSON.

    Returns:
        True on success, False if serializing or writing failed (the error
        and traceback are logged).
    """
    try:
        # Serialize before opening: opening with 'w' truncates the file, so
        # a serialization error must happen while the old config is intact.
        serialized = json_dump_program_config(pcfg)
        with open(shared.CONFIG_PATH, 'w', encoding='utf8') as f:
            f.write(serialized)
        LOGGER.info('Config saved')
        return True
    except Exception as e:
        LOGGER.error(f'Failed save config to {shared.CONFIG_PATH}: {e}')
        LOGGER.error(traceback.format_exc())
        return False
271
+
272
def save_text_styles(raise_exception = False):
    """Write the module-level ``text_styles`` to ``pcfg.text_styles_path``.

    Args:
        raise_exception: Re-raise the failure after logging it.

    Returns:
        True on success, False on failure (unless re-raised).
    """
    # Resolved up front so the error message below cannot itself fail when
    # no config is loaded.
    styles_path = pcfg.text_styles_path if pcfg is not None else None
    try:
        # Serialize first: open(..., 'w') truncates, so a serialization
        # error must not leave an empty style file behind.
        serialized = json_dump_nested_obj(text_styles)
        style_dir = osp.dirname(styles_path)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if style_dir:
            os.makedirs(style_dir, exist_ok=True)
        with open(styles_path, 'w', encoding='utf8') as f:
            f.write(serialized)
        LOGGER.info('Text style saved')
        return True
    except Exception as e:
        LOGGER.error(f'Failed save text style to {styles_path}: {e}')
        LOGGER.error(traceback.format_exc())
        if raise_exception:
            raise
        return False
utils/download_util.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import requests
4
+ import traceback
5
+ import re
6
+ import sys
7
+ import shutil
8
+ import os.path as osp
9
+ from typing import List, Union
10
+ import hashlib
11
+
12
+ from tqdm import tqdm
13
+ from urllib.parse import urlparse
14
+ from torch.hub import download_url_to_file as _torchhub_download_url_to_file, get_dir
15
+ import requests
16
+ import tqdm
17
+ from py7zr import pack_7zarchive, unpack_7zarchive
18
+ import ssl
19
+
20
+ from . import shared
21
+ from .logger import logger as LOGGER
22
+
23
+ shutil.register_archive_format('7zip', pack_7zarchive, description='7zip archive')
24
+ shutil.register_unpack_format('7zip', ['.7z'], unpack_7zarchive)
25
+
26
+
27
def calculate_sha256(filename):
    """Return the lowercase hex SHA-256 digest of the file at ``filename``.

    The file is read in 1 MiB blocks, so large files are not loaded whole.
    """
    digest = hashlib.sha256()
    block_size = 1024 * 1024

    with open(filename, "rb") as stream:
        while True:
            block = stream.read(block_size)
            if block == b"":
                break
            digest.update(block)

    return digest.hexdigest().lower()
36
+
37
+
38
def sizeof_fmt(size, suffix='B'):
    """Format a byte count as a human-readable string.

    Args:
        size (int): File size.
        suffix (str): Unit suffix appended after the prefix. Default: 'B'.

    Return:
        str: The size scaled by powers of 1024 with one decimal, e.g. '1.5 KB'.
    """
    scaled = size
    for prefix in ('', 'K', 'M', 'G', 'T', 'P', 'E', 'Z'):
        if abs(scaled) < 1024.0:
            return f'{scaled:3.1f} {prefix}{suffix}'
        scaled = scaled / 1024.0
    # Anything still >= 1024 after 'Z' is reported in 'Y'.
    return f'{scaled:3.1f} Y{suffix}'
53
+
54
+
55
def download_file_from_google_drive(file_id, save_path):
    """Download files from google drive.

    Ref:
    https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive  # noqa E501

    Args:
        file_id (str): File id.
        save_path (str): Save path.

    Raises:
        requests.HTTPError: If the download request returns an error status.
    """
    URL = 'https://docs.google.com/uc?export=download'
    params = {'id': file_id, 'confirm': 't'}  # https://stackoverflow.com/a/73893665/17671327

    # The session is a context manager so its connections are released even
    # if the download fails part-way.
    with requests.Session() as session:
        response = session.get(URL, params=params, stream=True)
        token = get_confirm_token(response)
        if token:
            response.close()
            params['confirm'] = token
            response = session.get(URL, params=params, stream=True)
        # Fail here instead of saving an HTTP error page as the file.
        response.raise_for_status()

        # get file size
        file_size = None
        with session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) as size_probe:
            content_range = size_probe.headers.get('Content-Range')
        if content_range:
            try:
                file_size = int(content_range.split('/')[1])
            except (IndexError, ValueError):
                # e.g. 'bytes 0-2/*': total unknown, download without a bar
                file_size = None

        try:
            save_response_content(response, save_path, file_size)
        finally:
            response.close()
84
+
85
+
86
def get_confirm_token(response):
    """Return the value of the first cookie named ``download_warning*``, else None."""
    matching_values = (
        cookie_value
        for cookie_name, cookie_value in response.cookies.items()
        if cookie_name.startswith('download_warning')
    )
    return next(matching_values, None)
91
+
92
+
93
def save_response_content(response, destination, file_size=None, chunk_size=32768):
    """Stream ``response`` into the file ``destination``.

    Args:
        response: Object with ``iter_content(chunk_size)`` yielding bytes.
        destination (str): Path of the file to (over)write.
        file_size (int | None): Total size in bytes; when given, a progress
            bar is shown.
        chunk_size (int): Number of bytes requested per chunk.
    """
    pbar = None
    readable_file_size = None
    if file_size is not None:
        # Import the class locally: at file level ``import tqdm`` comes
        # after ``from tqdm import tqdm`` and rebinds the name to the
        # module, so calling ``tqdm(...)`` raised TypeError.
        from tqdm import tqdm as _tqdm
        pbar = _tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
        readable_file_size = sizeof_fmt(file_size)

    try:
        with open(destination, 'wb') as f:
            downloaded_size = 0
            for chunk in response.iter_content(chunk_size):
                if not chunk:  # filter out keep-alive new chunks
                    continue
                f.write(chunk)
                # Count what was actually received; the last chunk is
                # usually shorter than chunk_size.
                downloaded_size += len(chunk)
                if pbar is not None:
                    pbar.update(1)
                    pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
    finally:
        if pbar is not None:
            pbar.close()
112
+
113
+ # def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
114
+ # """Load file from http url, will download models if necessary.
115
+
116
+ # Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
117
+
118
+ # Args:
119
+ # url (str): URL to be downloaded.
120
+ # model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
121
+ # Default: None.
122
+ # progress (bool): Whether to show the download progress. Default: True.
123
+ # file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
124
+
125
+ # Returns:
126
+ # str: The path to the downloaded file.
127
+ # """
128
+ # if model_dir is None: # use the pytorch hub_dir
129
+ # hub_dir = get_dir()
130
+ # model_dir = os.path.join(hub_dir, 'checkpoints')
131
+
132
+ # os.makedirs(model_dir, exist_ok=True)
133
+
134
+ # parts = urlparse(url)
135
+ # filename = os.path.basename(parts.path)
136
+ # if file_name is not None:
137
+ # filename = file_name
138
+ # cached_file = os.path.abspath(os.path.join(model_dir, filename))
139
+ # if not os.path.exists(cached_file):
140
+ # print(f'Downloading: "{url}" to {cached_file}\n')
141
+ # download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
142
+ # return cached_file
143
+
144
+
145
def torchhub_download_url_to_file(url: str, dst: str, progress: bool = True):
    """Download ``url`` to ``dst`` via torch.hub with HTTPS verification disabled.

    The process-wide default HTTPS context factory is swapped for the
    unverified one only for the duration of the call.

    Args:
        url: URL to download.
        dst: Destination file path.
        progress: Forwarded to torch.hub's ``download_url_to_file``.
    """
    original_ctx = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context  # https://stackoverflow.com/questions/50236117/scraping-ssl-certificate-verify-failed-error-for-http-en-wikipedia-org
    try:
        _torchhub_download_url_to_file(url, dst, progress=progress)
    finally:
        # Restore even when the download raises; otherwise certificate
        # verification stays disabled for the rest of the process.
        ssl._create_default_https_context = original_ctx
150
+
151
def check_local_file(local_file: str, sha256_precal: str = None, cache_hash: bool = False):
    """Check that ``local_file`` exists and, optionally, matches a SHA-256.

    The hash is only verified when the file exists, ``sha256_precal`` is
    given and ``shared.check_local_file_hash`` is truthy. With
    ``cache_hash`` a matching entry in ``shared.cache_data`` skips the
    recomputation, and every freshly computed hash is stored there.

    Returns:
        ``(file_exists, valid_hash, sha256_calculated)``;
        ``sha256_calculated`` is ``sha256_precal`` as passed in whenever no
        hash was computed.
    """
    file_exists = osp.exists(local_file)
    must_verify = file_exists and sha256_precal is not None and shared.check_local_file_hash
    if not must_verify:
        return file_exists, True, sha256_precal

    expected = sha256_precal.lower()
    cache_hit = (
        cache_hash
        and local_file in shared.cache_data
        and shared.cache_data[local_file].lower() == expected
    )
    if cache_hit:
        return file_exists, True, sha256_precal

    actual = calculate_sha256(local_file).lower()
    if cache_hash:
        shared.cache_data[local_file] = actual
        shared.CACHE_UPDATED = True
    return file_exists, actual == expected, actual
169
+
170
+
171
def get_filename_from_url(url: str, default: str = '') -> str:
    """Return the last path segment of ``url`` (query string excluded).

    ``default`` is returned when the URL has no such segment, e.g. when it
    ends with a slash or contains no slash at all.
    """
    last_segment = re.search(r'/([^/?]+)[^/]*$', url)
    return last_segment.group(1) if last_segment else default
176
+
177
+
178
def download_url_with_progressbar(url: str, path: str,):
    """Download ``url`` to ``path`` while showing a tqdm progress bar.

    If ``path`` is a directory (or has no file name) the file name is taken
    from the URL. An existing file at ``path`` is overwritten.

    Raises:
        Exception: If no file name can be determined, or the server answers
            with an error status.
    """
    # Local import of the class: the module-level name ``tqdm`` is the class
    # or the module depending on import order, so neither spelling is safe.
    from tqdm import tqdm as _tqdm

    if os.path.basename(path) in ('.', '') or os.path.isdir(path):
        new_filename = get_filename_from_url(url)
        if not new_filename:
            raise Exception('Could not determine filename')
        path = os.path.join(path, new_filename)

    # NOTE: resuming a partial download is deliberately not attempted; it is
    # buggy when the local file is corrupted, over-sized or meant to be
    # replaced, so every download starts from byte 0.
    chunk_size = 1024

    # Response as context manager: a streamed response keeps its connection
    # open until it is closed.
    with requests.get(url, stream=True, allow_redirects=True) as r:
        # Check the status before creating the bar or touching the file.
        if not r.ok:
            raise Exception(f'Couldn\'t resolve url: "{url}" (Error: {r.status_code})')

        total = int(r.headers.get('content-length', 0))
        is_tty = sys.stdout.isatty()
        with _tqdm(
            desc=os.path.basename(path),
            initial=0,
            total=total,
            unit='iB',
            unit_scale=True,
            unit_divisor=chunk_size,
        ) as bar, open(path, 'wb') as f:
            for downloaded_chunks, data in enumerate(r.iter_content(chunk_size=chunk_size), start=1):
                bar.update(f.write(data))

                # Fallback for non TTYs so output still shown
                if not is_tty and downloaded_chunks % 1000 == 0:
                    print(bar)
223
+
224
+
225
+
226
def try_download_files(url: str,
                       files: List[str],
                       save_files: List[str] = None,
                       sha256_pre_calculated: List[str] = None,
                       concatenate_url_filename: int = 0,
                       cache_hash: bool = False,
                       download_method: str = '',
                       gdrive_file_id: str = None):
    """Download every file that is missing locally or fails its hash check.

    Args:
        url: Base or full download URL.
        files: Remote file names.
        save_files: Local target paths, parallel to ``files``; defaults to
            ``files``.
        sha256_pre_calculated: Expected SHA-256 per file (entries may be
            None); defaults to no hash checking.
        concatenate_url_filename: 1 → ``url + file``, 2 → ``url +
            basename(file)``, anything else → ``url`` unchanged.
        cache_hash: Forwarded to ``check_local_file``.
        download_method: ``'torch_hub'`` selects the torch.hub downloader;
            otherwise the requests-based downloader is used.
        gdrive_file_id: When given, the file is fetched from Google Drive.

    Returns:
        True if every file is present and valid afterwards. Failures are
        logged, not raised.
    """
    # ``save_files`` used to default to the typing object ``List[str]``
    # (``=`` typed instead of ``:``) and a None hash list crashed zip().
    if save_files is None:
        save_files = files
    if sha256_pre_calculated is None:
        sha256_pre_calculated = [None] * len(files)

    all_successful = True

    for file, savep, sha256_precal in zip(files, save_files, sha256_pre_calculated):
        save_dir = osp.dirname(savep)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=cache_hash)
        if file_exists:
            if valid_hash:
                continue
            LOGGER.warning(f'Mismatch between local file {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}", it will be redownloaded...')

        # Bound before the try so the error message below can always use it.
        download_url = url
        try:
            if concatenate_url_filename == 1:
                download_url = url + file
            elif concatenate_url_filename == 2:
                download_url = url + osp.basename(file)

            if gdrive_file_id is not None:
                download_file_from_google_drive(gdrive_file_id, savep)
            elif download_method == 'torch_hub':
                LOGGER.info(f'downloading {savep} from {download_url} ...')
                torchhub_download_url_to_file(download_url, savep)
            else:
                download_url_with_progressbar(download_url, savep)
            file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=cache_hash)
            if not file_exists:
                raise Exception(f'Some how the downloaded {savep} doesnt exists.')
            elif not valid_hash:
                raise Exception(f'Mismatch between newly downloaded {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}"')

        except Exception:
            # ``Exception`` rather than a bare except, so Ctrl-C and
            # SystemExit still stop the program.
            err_msg = traceback.format_exc()
            all_successful = False
            LOGGER.error(err_msg)
            LOGGER.error(f'Failed downloading {file} from {download_url}, please manually save it to {savep}')

    return all_successful
277
+
278
+
279
def download_and_check_files(url: str,
                             files: Union[str, List],
                             save_files = None,
                             sha256_pre_calculated: Union[str, List] = None,
                             concatenate_url_filename: int = 0,
                             archived_files: List = None,
                             archive_sha256_pre_calculated: Union[str, List] = None,
                             save_dir: str = None,
                             download_method: str = 'torch_hub',
                             gdrive_file_id: str = None):
    """Make sure ``files`` exist locally with valid hashes, downloading if needed.

    Without ``archived_files`` each file is downloaded directly. With
    ``archived_files`` the archive is downloaded to ``shared.cache_dir``,
    unpacked, and the wanted files are moved to their targets; only the
    first archive is unpacked (multi-volume is not supported).

    Args:
        url: Base or full download URL.
        files: File name or list of file names (inside the archive when
            ``archived_files`` is given).
        save_files: Local target path(s); defaults to ``files``.
        sha256_pre_calculated: Expected SHA-256 value(s) for ``files``.
        concatenate_url_filename: See ``try_download_files``.
        archived_files: Archive name or list of archive names.
        archive_sha256_pre_calculated: Expected SHA-256 of the archive(s).
        save_dir: Directory prepended to every entry of ``save_files``.
        download_method: See ``try_download_files``.
        gdrive_file_id: Google Drive file id, if hosted there.

    Returns:
        bool: True when all files are present and valid.
    """

    def _wrap_up_checkinputs(files: Union[str, List], save_files: Union[str, List] = None, sha256_pre_calculated: Union[str, List] = None, save_dir: str = None):
        '''
        ensure they're lists with equal length
        '''
        if not isinstance(files, List):
            files = [files]
        if not isinstance(sha256_pre_calculated, List):
            if sha256_pre_calculated is None:
                sha256_pre_calculated = [None] * len(files)
            else:
                sha256_pre_calculated = [sha256_pre_calculated]
        if save_files is None:
            save_files = files
        elif not isinstance(save_files, List):
            save_files = [save_files]

        assert len(files) == len(sha256_pre_calculated) == len(save_files)

        if save_dir is not None:
            save_files = [osp.join(save_dir, savep) for savep in save_files]

        return files, save_files, sha256_pre_calculated

    def _all_valid(save_files: List[str] = None, sha256_pre_calculated: List[str] = None,):
        # True when every target exists and passes its hash check.
        for savep, sha256_precal in zip(save_files, sha256_pre_calculated):
            file_exists, valid_hash, _ = check_local_file(savep, sha256_precal, cache_hash=True)
            if not file_exists or not valid_hash:
                return False
        return True

    files, save_files, sha256_pre_calculated = _wrap_up_checkinputs(files, save_files, sha256_pre_calculated, save_dir)

    if archived_files is None:
        return try_download_files(url, files, save_files, sha256_pre_calculated, concatenate_url_filename, cache_hash=True, download_method=download_method, gdrive_file_id=gdrive_file_id)

    # handle archived
    if _all_valid(save_files, sha256_pre_calculated):
        # Was ``return [], None``: a tuple that is only accidentally truthy,
        # while every other path returns a bool.
        return True

    if isinstance(archived_files, str):
        archived_files = [archived_files]

    # download archive files
    tmp_downloaded_archives = [osp.join(shared.cache_dir, archive_name) for archive_name in archived_files]
    _, _, archive_sha256_pre_calculated = _wrap_up_checkinputs(archived_files, tmp_downloaded_archives, archive_sha256_pre_calculated)
    archive_downloaded = try_download_files(url, archived_files, tmp_downloaded_archives, archive_sha256_pre_calculated, concatenate_url_filename, cache_hash=False, download_method=download_method, gdrive_file_id=gdrive_file_id)
    if not archive_downloaded:
        return False

    # extract archived
    archivep = tmp_downloaded_archives[0]  # todo: support multi-volume
    extract_dir = osp.join(shared.cache_dir, 'tmp_extract')
    os.makedirs(extract_dir, exist_ok=True)
    LOGGER.info(f'Extracting {archivep} ...')
    shutil.unpack_archive(archivep, extract_dir)

    all_valid = True
    for file, savep, sha256_precal in zip(files, save_files, sha256_pre_calculated):
        unarchived = osp.join(extract_dir, file)
        target_dir = osp.dirname(savep)
        # dirname is '' for a bare file name; os.makedirs('') would raise.
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        # A file missing from the archive is reported by the check below
        # instead of crashing inside shutil.move.
        if osp.exists(unarchived):
            shutil.move(unarchived, savep)
        file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=True)
        if not file_exists:
            LOGGER.error(f'The unarchived file {savep} doesnt exists.')
            all_valid = False
        elif not valid_hash:
            LOGGER.error(f'Mismatch between the unarchived {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}"')
            all_valid = False

    if all_valid:
        # clean archive files
        shutil.rmtree(extract_dir)
        for p in tmp_downloaded_archives:
            os.remove(p)

    return all_valid
utils/exceptions.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
class ProjectDirNotExistException(Exception):
    """Project directory does not exist (meaning inferred from the name; raise sites are outside this file)."""


class ProjectLoadFailureException(Exception):
    """A project failed to load (meaning inferred from the name)."""


class ProjectNotSupportedException(Exception):
    """The project is not supported (meaning inferred from the name)."""


class ImgnameNotInProjectException(Exception):
    """An image name is not part of the project (meaning inferred from the name)."""


class NotImplementedProjException(Exception):
    """Project feature not implemented (meaning inferred from the name)."""


class InvalidModuleConfigException(Exception):
    """A module configuration is invalid (meaning inferred from the name)."""


class InvalidProgramConfigException(Exception):
    """The program configuration is invalid (meaning inferred from the name)."""
utils/fontformat.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+ import enum
3
+ import re
4
+ import copy
5
+
6
+ import numpy as np
7
+
8
+ from . import shared
9
+ from .structures import Tuple, Union, List, Dict, Config, field, nested_dataclass
10
+
11
+
12
def pt2px(pt, to_int=False) -> float:
    """Convert a length in points to pixels using ``shared.LDPI`` (72 pt per inch).

    When ``to_int`` is true the result is rounded to the nearest integer and
    returned as an ``int``; otherwise the unrounded value is returned.
    """
    pixels = pt * shared.LDPI / 72.
    return int(round(pixels)) if to_int else pixels
17
+
18
def px2pt(px) -> float:
    """Convert a length in pixels to points using ``shared.LDPI`` (72 pt per inch)."""
    inches = px / shared.LDPI
    return inches * 72.
20
+
21
+
22
class LineSpacingType(enum.IntEnum):
    """Integer codes for line-spacing modes.

    ``FontFormat.line_spacing_type`` defaults to ``Proportional``.
    """

    Proportional = 0
    Distance = 1
25
+
26
+
27
class TextAlignment(enum.IntEnum):
    """Integer codes for horizontal text alignment."""

    Left = 0
    Center = 1
    Right = 2
31
+
32
+
33
# Font-weight lookup tables used by fix_fontweight_qt: the first maps the
# sub-100 weights to their 100..900 counterparts, the second is its inverse.
fontweight_qt5_to_qt6 = {
    0: 100,
    12: 200,
    25: 300,
    50: 400,
    57: 500,
    63: 600,
    75: 700,
    81: 800,
    87: 900,
}
fontweight_qt6_to_qt5 = {qt6: qt5 for qt5, qt6 in fontweight_qt5_to_qt6.items()}

# Matches a CSS ``font-weight:<digits>`` declaration and captures the number.
fontweight_pattern = re.compile(r'font-weight:(\d+)', re.DOTALL)
37
+
38
def fix_fontweight_qt(weight: Union[str, int]):
    """Translate a font weight to the scale of the running Qt version.

    * ``None`` is returned unchanged.
    * An ``int`` below 100 is mapped through ``fontweight_qt5_to_qt6`` when
      ``shared.FLAG_QT6`` is set; an ``int`` of 100 or more is mapped through
      ``fontweight_qt6_to_qt5`` when it is not. Values missing from the
      tables are left as they are.
    * A ``str`` has every ``font-weight:<n>`` occurrence rewritten with the
      converted number.
    """
    if weight is None:
        return None

    if isinstance(weight, str):
        def _rewrite_css_weight(match):
            converted = fix_fontweight_qt(int(match.group(1)))
            return f'font-weight:{converted}'

        return fontweight_pattern.sub(_rewrite_css_weight, weight)

    if isinstance(weight, int):
        if shared.FLAG_QT6 and weight < 100:
            weight = fontweight_qt5_to_qt6.get(weight, weight)
        if not shared.FLAG_QT6 and weight >= 100:
            weight = fontweight_qt6_to_qt5.get(weight, weight)
    return weight
56
+
57
+
58
@nested_dataclass
class FontFormat(Config):
    """Text formatting options (font, colours, spacing, shadow, gradient).

    NOTE(review): relies on ``nested_dataclass`` and ``Config`` from
    ``.structures``; item access (``self[key]``), ``update`` and
    ``annotations_set`` used below are presumably provided by ``Config``.
    """

    font_family: str = shared.DEFAULT_FONT_FAMILY # to always apply shared.DEFAULT_FONT_FAMILY
    font_size: float = 24
    stroke_width: float = 0.
    # frgb / srgb are returned by foreground_color() / stroke_color() below
    frgb: List = field(default_factory=lambda: [0, 0, 0])
    srgb: List = field(default_factory=lambda: [0, 0, 0])
    bold: bool = False
    underline: bool = False
    italic: bool = False
    alignment: int = 0
    vertical: bool = False
    # None until set; normalised by fix_fontweight_qt in __post_init__
    font_weight: int = None
    line_spacing: float = 1.2
    letter_spacing: float = 1.15
    opacity: float = 1.
    shadow_radius: float = 0.
    shadow_strength: float = 1.
    shadow_color: List = field(default_factory=lambda: [0, 0, 0])
    shadow_offset: List = field(default_factory=lambda: [0., 0.])
    gradient_enabled: bool = False
    gradient_start_color: List = field(default_factory=lambda: [0, 0, 0])
    gradient_end_color: List = field(default_factory=lambda: [255, 255, 255])
    gradient_angle: float = 0.
    gradient_size: float = 1.0
    # excluded from the comparison in merge(compare=True)
    _style_name: str = ''
    line_spacing_type: int = LineSpacingType.Proportional

    # legacy keys ('size', 'weight', 'family') consumed once by __post_init__
    deprecated_attributes: dict = field(default_factory = lambda: dict())

    @property
    def size_pt(self):
        """``font_size`` converted with ``px2pt``."""
        return px2pt(self.font_size)

    def __post_init__(self):
        """Migrate legacy attributes and normalise the font weight.

        'size' is converted with ``pt2px``; 'weight' and 'family' are copied
        as-is. ``deprecated_attributes`` is emptied afterwards.
        """
        da = self.deprecated_attributes
        if len(da) > 0:
            if 'size' in da:
                self.font_size = pt2px(da['size'])
            if 'weight' in da:
                self.font_weight = da['weight']
            if 'family' in da:
                self.font_family = da['family']

        self.font_weight = fix_fontweight_qt(self.font_weight)
        self.deprecated_attributes = {}

    def deepcopy(self):
        """Return a deep copy of this format."""
        fmt_copyed: FontFormat = None
        fmt_copyed = copy.deepcopy(self)
        return fmt_copyed

    def merge(self, target: Config, compare: bool = False):
        """Copy values from ``target`` into this object.

        Only keys from ``target.annotations_set()`` that exist on ``self``
        are considered. With ``compare`` false every such key is overwritten
        with a deep copy and the returned set stays empty. With ``compare``
        true only differing keys (other than ``_style_name``) are overwritten
        and their names are returned.
        """
        if id(self) == id(target):
            return set()
        tgt_keys = target.annotations_set()
        updated_keys = set()
        for key in tgt_keys:
            if not hasattr(self, key):
                continue
            if compare:
                if key != '_style_name':
                    # ndarray comparison is elementwise, so reduce it with np.any
                    if isinstance(target[key], np.ndarray):
                        is_diff = np.any(self[key] != target[key])
                    else:
                        is_diff = self[key] != target[key]
                    if is_diff:
                        self.update(key, copy.deepcopy(target[key]))
                        updated_keys.add(key)
            else:
                self.update(key, copy.deepcopy(target[key]))
        return updated_keys

    def foreground_color(self):
        """Return ``frgb`` with each component rounded to an ``int``."""
        return [int(round(x)) for x in self.frgb]

    def stroke_color(self):
        """Return ``srgb`` with each component rounded to an ``int``."""
        return [int(round(x)) for x in self.srgb]
utils/imgproc_utils.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+ import random
4
+ from typing import List, Tuple, Union
5
+
6
def hex2bgr(hex):
    """Split packed 24-bit colour value(s) into three 8-bit channels.

    The value is read as ``0xBBGGRR``-style packing: the first returned
    channel is ``hex >> 16``, the second is bits 8-15 and the third is
    bits 0-7.

    Args:
        hex: an int, or an integer ndarray of packed colours.

    Returns:
        ndarray of shape ``(3,)`` for a scalar input, or ``(n, 3)`` for a
        1-D array of ``n`` colours.
    """
    # The channel masks must be 0xFF; the previous 0xFE (254) masks silently
    # cleared the lowest bit of the two low channels (e.g. 255 became 254).
    b = hex >> 16
    g = (hex >> 8) & 0xFF
    r = hex & 0xFF
    return np.stack([b, g, r]).transpose()
13
+
14
def union_area(bboxa, bboxb):
    """Return the overlap area of two ``(x1, y1, x2, y2)`` boxes.

    Despite the name, the value computed is the area of the intersection
    rectangle. ``-1`` is returned when the boxes do not overlap; boxes that
    merely touch yield ``0``.
    """
    left, top = max(bboxa[0], bboxb[0]), max(bboxa[1], bboxb[1])
    right, bottom = min(bboxa[2], bboxb[2]), min(bboxa[3], bboxb[3])
    overlap_w, overlap_h = right - left, bottom - top
    if overlap_h < 0 or overlap_w < 0:
        return -1
    return overlap_h * overlap_w
22
+
23
def get_yololabel_strings(clslist, labellist):
    """Format class ids and boxes as label text, one ``cls v1 v2 ...`` row per line.

    Class ids are cast with ``int``; box values are written with ``str``.
    Rows are separated by ``'\\n'`` with no trailing newline. Empty input
    gives an empty string.
    """
    rows = [
        str(int(cls)) + ' ' + ' '.join(str(value) for value in xywh)
        for cls, xywh in zip(clslist, labellist)
    ]
    return '\n'.join(rows)
30
+
31
# 4-value bbox (x, y, w, h) to 8-value polygon
def xywh2xyxypoly(xywh, to_int=True):
    """Expand ``(n, 4)`` xywh boxes into ``(n, 8)`` corner polygons.

    Each row becomes ``x, y, x+w, y, x+w, y+h, x, y+h``. With ``to_int`` the
    result is cast to ``int64``; otherwise the input dtype is kept.
    """
    x, y = xywh[:, 0], xywh[:, 1]
    x_far, y_far = x + xywh[:, 2], y + xywh[:, 3]
    polygons = np.stack([x, y, x_far, y, x_far, y_far, x, y_far], axis=1)
    return polygons.astype(np.int64) if to_int else polygons
39
+
40
def xyxy2yolo(xyxy, w: int, h: int):
    """Convert pixel ``x1, y1, x2, y2`` boxes to normalised ``cx, cy, w, h``.

    Args:
        xyxy: one box (length-4 sequence) or several (``(n, 4)`` list/array).
        w: image width used to normalise x values.
        h: image height used to normalise y values.

    Returns:
        ``float64`` ndarray of shape ``(n, 4)``, or ``None`` for ``None`` or
        empty input. The input is never modified.
    """
    # Emptiness is tested with len() only: comparing an ndarray with ``[]``
    # or ``np.array([])`` via ``==`` raises on recent NumPy releases because
    # the shapes cannot be broadcast.
    if xyxy is None or len(xyxy) == 0:
        return None
    yolo = np.array(xyxy, dtype=np.float64)  # always a fresh copy
    if yolo.ndim == 1:
        yolo = yolo[None]
    yolo[:, [0, 2]] = yolo[:, [0, 2]] / w
    yolo[:, [1, 3]] = yolo[:, [1, 3]] / h
    yolo[:, [2, 3]] -= yolo[:, [0, 1]]      # x2, y2 -> width, height
    yolo[:, [0, 1]] += yolo[:, [2, 3]] / 2  # x1, y1 -> centre
    return yolo
53
+
54
def yolo_xywh2xyxy(xywh: np.array, w: int, h: int, to_int=True):
    """Convert normalised ``cx, cy, w, h`` boxes to pixel ``x1, y1, x2, y2``.

    Args:
        xywh: one box (length 4) or several (``(n, 4)``); ``None`` or empty
            input returns ``None``.
        w: image width in pixels.
        h: image height in pixels.
        to_int: cast the result to ``int64`` (truncating) when true.

    Returns:
        ndarray of shape ``(n, 4)`` or ``None``.
    """
    if xywh is None or len(xywh) == 0:
        return None
    # Work on a private copy: the conversion below is done in place and
    # previously overwrote the caller's 2-D array.
    boxes = np.array(xywh)
    if not np.issubdtype(boxes.dtype, np.floating):
        # in-place true division cannot be stored in an integer array
        boxes = boxes.astype(np.float64)
    if boxes.ndim == 1:
        boxes = boxes[None]
    boxes[:, [0, 2]] *= w
    boxes[:, [1, 3]] *= h
    boxes[:, [0, 1]] -= boxes[:, [2, 3]] / 2  # centre -> top-left
    boxes[:, [2, 3]] += boxes[:, [0, 1]]      # size -> bottom-right
    if to_int:
        boxes = boxes.astype(np.int64)
    return boxes
68
+
69
def rotate_polygons(center, polygons, rotation, new_center=None, to_int=True):
    """Rotate flattened polygons about ``center`` by ``rotation`` degrees.

    ``polygons`` is an ``(n, 2k)`` array of interleaved ``x, y`` values. The
    points are shifted so ``center`` is the origin, rotated, then shifted to
    ``new_center`` (``center`` when omitted). The computation is carried out
    in ``float32`` on a copy; with ``to_int`` the result is cast to ``int64``.
    """
    target = center if new_center is None else new_center
    theta = np.deg2rad(rotation)
    sin_t, cos_t = np.sin(theta), np.cos(theta)

    pts = polygons.astype(np.float32)  # astype copies, the input stays untouched
    pts[:, 1::2] -= center[1]
    pts[:, ::2] -= center[0]
    xs, ys = pts[:, ::2], pts[:, 1::2]

    out = np.copy(pts)
    out[:, 1::2] = ys * cos_t - xs * sin_t
    out[:, ::2] = ys * sin_t + xs * cos_t
    out[:, 1::2] += target[1]
    out[:, ::2] += target[0]
    return out.astype(np.int64) if to_int else out
86
+
87
def letterbox(im, new_shape=(640, 640), color=(0, 0, 0), auto=False, scaleFill=False, scaleup=True, stride=128):
    """Resize ``im`` keeping aspect ratio and pad it to ``new_shape``.

    Args:
        im: image array whose first two dimensions are (height, width).
        new_shape: target ``(height, width)``; a scalar means a square.
        color: border fill value.
        auto: pad only up to the next multiple of ``stride``.
        scaleFill: stretch to ``new_shape`` exactly, without padding.
        scaleup: allow enlarging; when false the image is only shrunk.
        stride: modulus used by ``auto``.

    Returns:
        ``(image, (ratio_w, ratio_h), (dw, dh))`` where ``dw``/``dh`` are the
        padding added on the right and bottom edges.
    """
    shape = im.shape[:2]  # current shape [height, width]
    # accept lists as well as tuples; only scalars are expanded to a square
    if not isinstance(new_shape, (tuple, list)):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, stride), np.mod(dh, stride)
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    # The padding is not split between opposite sides: all of it goes to the
    # bottom and right edges, so the image content stays anchored at (0, 0).
    dh, dw = int(dh), int(dw)

    if shape[::-1] != new_unpad:  # resize
        im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
    im = cv2.copyMakeBorder(im, 0, dh, 0, dw, cv2.BORDER_CONSTANT, value=color)  # add border
    return im, ratio, (dw, dh)
119
+
120
def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR, stride=None):
    """Resize ``im`` to fit inside ``new_shape`` while keeping aspect ratio.

    ``new_shape`` may be a ``(height, width)`` tuple, a scalar (square), or
    ``None`` (keep the current size). With ``scaleup`` false the image is
    never enlarged. When ``stride`` is given, both output dimensions are
    rounded up to the next multiple of it (which may change the aspect
    ratio slightly). The image is returned unchanged if no resize is needed.
    """
    shape = im.shape[:2]  # current shape [height, width]

    if new_shape is None:
        new_shape = shape
    elif not isinstance(new_shape, tuple):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    scale = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:
        scale = min(scale, 1.0)

    out_w, out_h = int(round(shape[1] * scale)), int(round(shape[0] * scale))

    if stride is not None:
        def _round_up(value):
            remainder = value % stride
            return value if remainder == 0 else (stride - remainder) + value

        out_w, out_h = _round_up(out_w), _round_up(out_h)

    new_unpad = (out_w, out_h)  # cv2.resize takes (width, height)
    if shape[::-1] != new_unpad:
        im = cv2.resize(im, new_unpad, interpolation=interpolation)
    return im
151
+
152
def expand_textwindow(img_size, xyxy, expand_r=8, shrink=False):
    """Grow (or shrink) an ``x1, y1, x2, y2`` box by a size-dependent margin.

    The margin is ``(0.25 * long_side + 0.75 * short_side) / expand_r``,
    rounded to an int, and is negated when ``shrink`` is true. The result is
    clamped to ``[0, width - 1]`` / ``[0, height - 1]`` of ``img_size``
    (given as ``(height, width, ...)``).
    """
    im_h, im_w = img_size[:2]
    left, top, right, bottom = xyxy
    box_w, box_h = right - left, bottom - top
    long_side, short_side = max(box_h, box_w), min(box_h, box_w)
    margin = int(round((long_side * 0.25 + short_side * 0.75) / expand_r))
    if shrink:
        margin = -margin
    return [
        max(0, left - margin),
        max(0, top - margin),
        min(im_w - 1, right + margin),
        min(im_h - 1, bottom + margin),
    ]
163
+
164
def enlarge_window(rect, im_w, im_h, ratio=2.5, aspect_ratio=1.0) -> List:
    """Enlarge an ``x1, y1, x2, y2`` rectangle so its area grows by about ``ratio``.

    The total growth ``d`` in height (``aspect_ratio * d`` in width) is the
    largest root of ``(w + aspect_ratio * d) * (h + d) = ratio * w * h``;
    half of it is added on each side, limited by the distance to the image
    border. Returns ``[0, 0, 0, 0]`` for a rectangle with non-positive width
    or height.
    """
    assert ratio > 1.0

    x1, y1, x2, y2 = rect
    w, h = x2 - x1, y2 - y1
    if w <= 0 or h <= 0:
        return [0, 0, 0, 0]

    # https://numpy.org/doc/stable/reference/generated/numpy.roots.html
    roots = np.roots([aspect_ratio, w + h * aspect_ratio, (1 - ratio) * w * h])
    roots.sort()
    grow_y = int(round(roots[-1] / 2))
    grow_x = int(grow_y * aspect_ratio)
    # never grow further than the nearest image border on either side
    grow_x = min(x1, im_w - x2, grow_x)
    grow_y = min(y1, im_h - y2, grow_y)

    enlarged = np.array([x1 - grow_x, y1 - grow_y, x2 + grow_x, y2 + grow_y], dtype=np.int64)
    enlarged[::2] = np.clip(enlarged[::2], 0, im_w)
    enlarged[1::2] = np.clip(enlarged[1::2], 0, im_h)
    return enlarged.tolist()
186
+
187
def draw_connected_labels(num_labels, labels, stats, centroids, names="draw_connected_labels", skip_background=True):
    """Render connected-component labels in random colours and show them.

    Args:
        num_labels: number of labels, or an iterable of label ids to draw.
        labels: 2-D label image.
        stats: per-label rows whose columns 2 and 3 are used as width/height.
        centroids: per-label ``(x, y)`` centres.
        names: window title passed to ``cv2.imshow``.
        skip_background: skip label ``0``.

    Returns:
        The BGR ``uint8`` visualisation image (also displayed with
        ``cv2.imshow``; no ``waitKey`` is issued here).
    """
    labdraw = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
    max_ind = 0
    if isinstance(num_labels, int):
        num_labels = range(num_labels)

    maxr, minr = 0.5, 0.001
    minarea = labdraw.shape[0] * labdraw.shape[1] * minr  # loop-invariant

    for lab in num_labels:
        if skip_background and lab == 0:
            continue
        randcolor = (random.randint(0,255), random.randint(0,255), random.randint(0,255))
        labdraw[np.where(labels==lab)] = randcolor
        # size limits are relative to the box of label ``max_ind`` (label 0)
        maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr

        stat = stats[lab]
        bboxarea = stat[2] * stat[3]
        if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
            pix = np.zeros((labels.shape[0], labels.shape[1]), dtype=np.uint8)
            pix[np.where(labels==lab)] = 255

            rect = cv2.minAreaRect(cv2.findNonZero(pix))
            # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased
            box = cv2.boxPoints(rect).astype(np.intp)
            labdraw = cv2.drawContours(labdraw, [box], 0, randcolor, 2)
            labdraw = cv2.circle(labdraw, (int(centroids[lab][0]),int(centroids[lab][1])), radius=5, color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)), thickness=-1)

    cv2.imshow(names, labdraw)
    return labdraw
216
+
217
def rotate_image(mat: np.ndarray, angle: float) -> np.ndarray:
    """
    Rotate an image by ``angle`` degrees, enlarging the canvas so nothing is cropped.
    # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
    """
    src_h, src_w = mat.shape[:2]
    # getRotationMatrix2D wants (x, y), i.e. the reverse of the shape order
    center = (src_w/2, src_h/2)

    matrix = cv2.getRotationMatrix2D(center, angle, 1.)

    # |cos| and |sin| of the rotation give the bounding box of the rotated image
    cos_a, sin_a = abs(matrix[0,0]), abs(matrix[0,1])
    out_w = int(src_h * sin_a + src_w * cos_a)
    out_h = int(src_h * cos_a + src_w * sin_a)

    # shift the transform so the old centre lands on the new canvas centre
    matrix[0, 2] += out_w/2 - center[0]
    matrix[1, 2] += out_h/2 - center[1]

    return cv2.warpAffine(mat, matrix, (out_w, out_h))
243
+
244
def color_difference(rgb1: List, rgb2: List) -> float:
    """Return a CIE76-style distance between two RGB colours.

    Both colours are converted to LAB with OpenCV as single ``uint8`` pixels;
    the L difference is scaled by 0.392 (presumably to bring OpenCV's 8-bit
    L channel back to a 0-100 range) before the Euclidean norm is taken.
    # https://en.wikipedia.org/wiki/Color_difference#CIE76
    """
    def _to_lab(rgb):
        pixel = np.array(rgb, dtype=np.uint8).reshape(1, 1, 3)
        return cv2.cvtColor(pixel, cv2.COLOR_RGB2LAB).astype(np.float64)

    delta = _to_lab(rgb1) - _to_lab(rgb2)
    delta[..., 0] *= 0.392
    return np.linalg.norm(delta, axis=2).item()
252
+
253
def extract_ballon_region(img: np.ndarray, ballon_rect: List, show_process=False, enlarge_ratio=2.0, cal_region_rect=False) -> Tuple[np.ndarray, int, List]:
    """Estimate the balloon mask around a text rectangle.

    Args:
        img: source image; it is cropped and passed to OpenCV filters.
        ballon_rect: ``(x, y, w, h)`` of the text region.
        show_process: display the mask and the crop with ``cv2.imshow`` and
            block on ``cv2.waitKey(0)``.
        enlarge_ratio: when > 1 the rectangle is enlarged with
            ``enlarge_window`` before cropping.
        cal_region_rect: also return ``cv2.boundingRect`` of the mask.

    Returns:
        ``(mask, nonzero_pixel_count, [x1, y1, x2, y2])`` for the cropped
        window. NOTE: with ``cal_region_rect`` a fourth element (the mask's
        bounding rect) is appended, which the return annotation does not show.
    """
    WHITE = (255, 255, 255)
    BLACK = (0, 0, 0)

    # xywh -> xyxy
    x1, y1, x2, y2 = ballon_rect[0], ballon_rect[1], \
        ballon_rect[2] + ballon_rect[0], ballon_rect[3] + ballon_rect[1]
    if enlarge_ratio > 1:
        # NOTE(review): divides by ballon_rect[2]; a zero-width rect raises here
        x1, y1, x2, y2 = enlarge_window([x1, y1, x2, y2], img.shape[1], img.shape[0], enlarge_ratio, aspect_ratio=ballon_rect[3] / ballon_rect[2])

    img = img[y1:y2, x1:x2].copy()

    kernel = np.ones((3,3),np.uint8)
    orih, oriw = img.shape[0], img.shape[1]
    # large crops are shrunk, small ones enlarged, before edge detection
    scaleR = 1
    if orih > 300 and oriw > 300:
        scaleR = 0.6
    elif orih < 120 or oriw < 120:
        scaleR = 1.4

    if scaleR != 1:
        h, w = img.shape[0], img.shape[1]
        orimg = np.copy(img)
        img = cv2.resize(img, (int(w*scaleR), int(h*scaleR)), interpolation=cv2.INTER_AREA)
    h, w = img.shape[0], img.shape[1]
    img_area = h * w

    # NOTE(review): the third positional argument of GaussianBlur is sigmaX,
    # so cv2.BORDER_DEFAULT is used as the sigma value here - confirm intent
    cpimg = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
    detected_edges = cv2.Canny(cpimg, 70, 140, L2gradient=True, apertureSize=3)
    # a white frame is drawn before contour extraction and erased right after
    cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
    cons, hiers = cv2.findContours(detected_edges, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), BLACK, 1, cv2.LINE_8)

    # outer_index is assigned but not used in this function
    ballon_mask, outer_index = np.zeros((h, w), np.uint8), -1
    min_retval = np.inf
    mask = np.zeros((h, w), np.uint8)
    difres = 10
    seedpnt = (int(w/2), int(h/2))  # flood fills start from the crop centre
    for ii in range(len(cons)):
        rect = cv2.boundingRect(cons[ii])
        # ignore contours whose bounding box covers under 40% of the crop
        if rect[2]*rect[3] < img_area*0.4:
            continue

        # contours accumulate on ``mask``; the fill is tried on a copy
        mask = cv2.drawContours(mask, cons, ii, (255), 2)
        cpmask = np.copy(mask)
        cv2.rectangle(mask, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
        retval, _, _, rect = cv2.floodFill(cpmask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))

        # a contour that squeezes the fill to <= 30% of the crop is erased again
        if retval <= img_area * 0.3:
            mask = cv2.drawContours(mask, cons, ii, (0), 2)
        # keep the smallest filled region that is still above 30% of the crop
        if retval < min_retval and retval > img_area * 0.3:
            min_retval = retval
            ballon_mask = cpmask

    # filled pixels (127) become 0; the second fill from the centre then marks
    # the connected balloon area, which is thresholded into a binary mask
    ballon_mask = 127 - ballon_mask
    ballon_mask = cv2.dilate(ballon_mask, kernel,iterations = 1)
    ballon_area, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(30), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
    ballon_mask = 30 - ballon_mask
    retval, ballon_mask = cv2.threshold(ballon_mask, 1, 255, cv2.THRESH_BINARY)
    ballon_mask = cv2.bitwise_not(ballon_mask, ballon_mask)

    # dilate then erode with a kernel sized from the balloon area
    box_kernel = int(np.sqrt(ballon_area) / 30)
    if box_kernel > 1:
        box_kernel = np.ones((box_kernel,box_kernel),np.uint8)
        ballon_mask = cv2.dilate(ballon_mask, box_kernel, iterations = 1)
        ballon_mask = cv2.erode(ballon_mask, box_kernel, iterations = 1)

    # undo the working-resolution change so the mask matches the crop size
    if scaleR != 1:
        img = orimg
        ballon_mask = cv2.resize(ballon_mask, (oriw, orih))

    if show_process:
        cv2.imshow('ballon_mask', ballon_mask)
        cv2.imshow('img', img)
        cv2.waitKey(0)
    if cal_region_rect:
        return ballon_mask, (ballon_mask > 0).sum(), [x1, y1, x2, y2], cv2.boundingRect(ballon_mask)
    return ballon_mask, (ballon_mask > 0).sum(), [x1, y1, x2, y2]
330
+
331
def square_pad_resize(img: np.ndarray, tgt_size: int):
    """Pad ``img`` to a square (bottom/right) and shrink it to ``tgt_size``.

    The shorter side is padded up to the longer one; if the square is still
    smaller than ``tgt_size`` both sides are padded further. A square larger
    than ``tgt_size`` is resized down with ``INTER_AREA``.

    Returns:
        ``(image, down_scale_ratio, pad_h, pad_w)``.
    """
    h, w = img.shape[:2]
    side = max(h, w)
    pad_h, pad_w = side - h, side - w

    # pad up to the target when the square is still too small
    shortfall = tgt_size - side
    if shortfall > 0:
        pad_h += shortfall
        pad_w += shortfall

    if pad_h > 0 or pad_w > 0:
        img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT)

    down_scale_ratio = tgt_size / img.shape[0]
    assert down_scale_ratio <= 1
    if down_scale_ratio < 1:
        img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)

    return img, down_scale_ratio, pad_h, pad_w
357
+
358
+
359
+
360
def get_block_mask(xywh: List, mask_array: np.ndarray, angle: int):
    """Cut the part of ``mask_array`` covered by an (optionally rotated) box.

    Args:
        xywh: ``(x, y, w, h)`` of the box.
        mask_array: 2-D (or higher) mask; its first two dims are (height, width).
        angle: rotation in degrees; ``0`` takes a plain axis-aligned slice.

    Returns:
        ``(mask, [x1, y1, x2, y2])`` with the window clipped to the image, or
        ``(None, None)`` when the window is thinner than 2 px or lies outside
        the image. For a rotated box the slice is AND-ed with the filled
        polygon of the rotated rectangle.
    """
    x, y, w, h = xywh
    im_h, im_w = mask_array.shape[:2]

    def _is_unusable(wx1, wy1, wx2, wy2):
        return wx2 < 0 or wx2 - wx1 < 2 or wx1 >= im_w - 1 \
            or wy2 < 0 or wy2 - wy1 < 2 or wy1 >= im_h - 1

    if angle == 0:
        x1, y1, x2, y2 = x, y, x+w, y+h
        if _is_unusable(x1, y1, x2, y2):
            return None, None
        x1 = 0 if x1 < 0 else x1
        x2 = im_w if x2 > im_w else x2
        y1 = 0 if y1 < 0 else y1
        y2 = im_h if y2 > im_h else y2
        return mask_array[y1: y2, x1: x2], [x1, y1, x2, y2]

    cx, cy = x + int(round(w / 2)), y + int(round(h / 2))
    poly = xywh2xyxypoly(np.array([[x, y, w, h]]))
    poly = rotate_polygons([cx, cy], poly, -angle)

    # axis-aligned bounding window of the rotated rectangle
    x1, x2 = np.min(poly[..., ::2]), np.max(poly[..., ::2])
    y1, y2 = np.min(poly[..., 1::2]), np.max(poly[..., 1::2])
    if _is_unusable(x1, y1, x2, y2):
        return None, None

    # move the polygon into the coordinate frame of the window-sized mask
    poly[..., ::2] -= cx - int((x2 - x1) / 2)
    poly[..., 1::2] -= cy - int((y2 - y1) / 2)
    itmsk = np.zeros((y2 - y1, x2 - x1), np.uint8)
    cv2.fillPoly(itmsk, poly.reshape(-1, 4, 2), color=(255))

    # clip the window to the image and trim the polygon mask by the same
    # amount (px2 / py2 become negative end indices when the window overshoots)
    px1, x1 = (-x1, 0) if x1 < 0 else (0, x1)
    px2, x2 = (im_w - x2, im_w) if x2 > im_w else (itmsk.shape[1], x2)
    py1, y1 = (-y1, 0) if y1 < 0 else (0, y1)
    py2, y2 = (im_h - y2, im_h) if y2 > im_h else (itmsk.shape[0], y2)
    itmsk = itmsk[py1: py2, px1: px2]
    msk = cv2.bitwise_and(mask_array[y1: y2, x1: x2], itmsk)

    return msk, [x1, y1, x2, y2]
413
+
utils/io_utils.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json, os, sys, time, io
2
+ import os.path as osp
3
+ from pathlib import Path
4
+ import importlib
5
+ from typing import List, Dict, Callable, Union
6
+ import base64
7
+ import traceback
8
+
9
+ from .logger import logger as LOGGER
10
+ import requests
11
+ from PIL import Image
12
+ import PIL
13
+ import cv2
14
+ import numpy as np
15
+ import pillow_jxl
16
+ from natsort import natsorted
17
+
18
# Lower-case file suffixes treated as images.
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg', '.webp', '.jxl']

# NumPy scalar types grouped by the Python type they are serialised to
# (see serialize_np).
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
# ``np.bool8`` / ``np.bool`` are aliases of ``np.bool_`` and ``np.float_`` is an
# alias of ``np.float64``, so listing them adds nothing to an isinstance()
# check. Dropping them removes the NumPy-version branch and avoids touching
# names that are deprecated in NumPy 1.24+ and removed in NumPy 2.
NP_BOOL_TYPES = (np.bool_,)
NP_FLOAT_TYPES = (np.float16, np.float32, np.float64)
27
+
28
def to_dict(obj):
    """Convert ``obj`` to plain JSON-compatible data via a JSON round trip.

    Objects that ``json`` cannot encode are replaced by their ``__dict__``.
    """
    def _as_attrs(o):
        return o.__dict__

    encoded = json.dumps(obj, default=_as_attrs, ensure_ascii=False)
    return json.loads(encoded)
30
+
31
def serialize_np(obj):
    """Convert NumPy arrays and scalars to plain Python values.

    Arrays become nested lists. Scalars matching ``NP_BOOL_TYPES``,
    ``NP_FLOAT_TYPES`` or ``NP_INT_TYPES`` (checked in that order) become
    ``bool``, ``float`` or ``int``. Anything else is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if not isinstance(obj, np.ScalarType):
        return obj

    conversions = (
        (NP_BOOL_TYPES, bool),
        (NP_FLOAT_TYPES, float),
        (NP_INT_TYPES, int),
    )
    for np_types, py_type in conversions:
        if isinstance(obj, np_types):
            return py_type(obj)
    return obj
42
+
43
def json_dump_nested_obj(obj, **kwargs):
    """Serialise ``obj`` to a JSON string, handling NumPy values and plain objects.

    NumPy arrays/scalars go through ``serialize_np``; any other value that
    ``json`` cannot encode is replaced by its ``__dict__``. Extra keyword
    arguments are forwarded to ``json.dumps``.
    """
    def _fallback(value):
        if isinstance(value, (np.ndarray, np.ScalarType)):
            return serialize_np(value)
        return value.__dict__

    return json.dumps(obj, default=_fallback, ensure_ascii=False, **kwargs)
49
+
50
# https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
class NumpyEncoder(json.JSONEncoder):
    """``json`` encoder that converts NumPy arrays and scalars with ``serialize_np``."""

    def default(self, obj):
        """Encode NumPy values; defer everything else to ``JSONEncoder.default``."""
        is_numpy_value = isinstance(obj, (np.ndarray, np.ScalarType))
        if not is_numpy_value:
            return super().default(obj)
        return serialize_np(obj)
56
+
57
def find_all_imgs(img_dir, abs_path=False, sort=False):
    """List the image files directly inside ``img_dir`` (not recursive).

    A file counts as an image when its lower-cased suffix is in ``IMG_EXT``.
    With ``abs_path`` each name is joined onto ``img_dir``; with ``sort`` the
    list is ordered with ``natsorted``.
    """
    found = [
        name for name in os.listdir(img_dir)
        if Path(name).suffix.lower() in IMG_EXT
    ]
    if abs_path:
        found = [osp.join(img_dir, name) for name in found]
    if sort:
        found = natsorted(found)
    return found
72
+
73
def find_all_files_recursive(tgt_dir: Union[List, str], ext: Union[List, set], exclude_dirs=None):
    """Collect files with a matching suffix below one or more directories.

    Args:
        tgt_dir: a directory path or a list of them.
        ext: lower-case suffixes (with the dot) to accept.
        exclude_dirs: directory base names whose own files are skipped.
            Their sub-directories are still walked.

    Returns:
        List of paths (each walked root joined with the file name).
    """
    roots = [tgt_dir] if isinstance(tgt_dir, str) else tgt_dir
    skipped = set() if exclude_dirs is None else exclude_dirs

    matches = []
    for top in roots:
        for current, _, filenames in os.walk(top):
            if osp.basename(current) not in skipped:
                matches.extend(
                    osp.join(current, fname)
                    for fname in filenames
                    if Path(fname).suffix.lower() in ext
                )
    return matches
90
+
91
def imread(imgpath, read_type=cv2.IMREAD_COLOR, max_retry_limit=5, retry_interval=0.1):
    """Read an image with PIL and return it as a NumPy array.

    Args:
        imgpath: path of the image file.
        read_type: ``cv2.IMREAD_GRAYSCALE`` returns a single-channel image;
            any other value converts the image to RGB.
        max_retry_limit: how many ``PIL.UnidentifiedImageError`` failures to
            tolerate before giving up; ``None`` retries forever.
        retry_interval: seconds to sleep between retries.

    Returns:
        The image array, or ``None`` if the file does not exist or could not
        be identified within the retry limit.

    Raises:
        RuntimeError: a grayscale read produced a 3-D array whose channel
            count is not 1, 3 or 4.
    """
    if not osp.exists(imgpath):
        return None

    num_tries = 0
    while True:
        try:
            # the context manager closes the file handle PIL keeps open
            with Image.open(imgpath) as src:
                pil_img = src if read_type == cv2.IMREAD_GRAYSCALE else src.convert('RGB')
                img = np.array(pil_img)
            break
        except PIL.UnidentifiedImageError as e:
            # IMG I/O thread might not be finished yet
            num_tries += 1
            if max_retry_limit is not None and num_tries >= max_retry_limit:
                LOGGER.exception(e)
                return None
            LOGGER.warning(f'PIL.UnidentifiedImageError: failed to read {imgpath}, retries: {num_tries} / {max_retry_limit}')
            time.sleep(retry_interval)

    if read_type == cv2.IMREAD_GRAYSCALE:
        if img.ndim == 3:
            if img.shape[-1] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            elif img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
            elif img.shape[-1] == 1:
                img = img[..., 0]
            else:
                # A bare ``raise`` here had no active exception to re-raise
                # and surfaced as an uninformative RuntimeError; the type is
                # kept but the message now says what went wrong.
                raise RuntimeError(f'cannot convert {imgpath} to grayscale: unsupported channel count {img.shape[-1]}')
    return img
123
+
124
+
125
def imwrite(img_path, img, ext='.png', quality=100, jxl_encode_effort=3):
    """Write an RGB(A)/grayscale array to disk in the format given by ``ext``.

    The suffix of ``img_path`` is replaced by ``ext`` (or ``ext`` is appended
    when the path has no suffix).

    Args:
        img_path: destination path as a ``str``.
        img: image array; 3- and 4-channel arrays are taken as RGB / RGBA.
        ext: target suffix, must be in ``IMG_EXT`` (case-insensitive).
        quality: JPEG / WebP / JXL quality; JXL is written lossless above 99.
        jxl_encode_effort: ``effort`` passed to the JXL encoder.
    """
    # cv2 writing is faster than PIL
    suffix = Path(img_path).suffix
    ext = ext.lower()
    assert ext in IMG_EXT
    if suffix != '':
        # Swap only the last occurrence of the suffix. str.replace() would
        # also rewrite earlier matches, e.g. a directory named 'scan.png.d'.
        cut = img_path.rfind(suffix)
        img_path = img_path[:cut] + ext + img_path[cut + len(suffix):]
    else:
        img_path += ext
    encode_param = None
    if ext in {'.jpg', '.jpeg'}:
        encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
    elif ext == '.webp':
        encode_param = [cv2.IMWRITE_WEBP_QUALITY, quality]
    if ext == '.jxl':
        # jxl_encode_effort: https://github.com/Isotr0py/pillow-jpegxl-plugin/issues/23
        # higher values theoretically produce smaller files at the expense of time, 3 seems to strike a balance
        lossless = quality > 99  # quality=100, lossless=False seems to result in larger file compared with lossless=True
        Image.fromarray(img).save(img_path, quality=quality, lossless=lossless, effort=jxl_encode_effort)
    else:
        # OpenCV expects BGR(A) channel order
        if len(img.shape) == 3:
            if img.shape[-1] == 3:
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            elif img.shape[-1] == 4:
                img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
        # encode in memory and write with tofile()
        cv2.imencode(ext, img, encode_param)[1].tofile(img_path)
151
+
152
+
153
def show_img_by_dict(imgdicts):
    """Show every image of a ``{window_title: image}`` mapping and wait for a key press."""
    for title, image in imgdicts.items():
        cv2.imshow(title, image)
    cv2.waitKey(0)
157
+
158
def text_is_empty(text) -> bool:
    """Tell whether ``text`` holds no visible content.

    ``None``, blank/whitespace-only strings and lists whose items are all
    empty (checked recursively, so ``[]`` counts as empty) give ``True``.
    Everything else gives ``False``; previously those cases fell off the end
    of the function and returned ``None`` despite the ``bool`` annotation.
    """
    if text is None:
        return True
    if isinstance(text, str):
        return text.strip() == ''
    if isinstance(text, list):
        return all(text_is_empty(item) for item in text)
    return False
170
+
171
def empty_func(*args, **kwargs):
    """Accept any arguments, do nothing and return ``None``."""
    return None
173
+
174
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``'pkg.module.Name'`` to the named attribute.

    The text before the last dot is imported as a module and the rest is
    looked up on it. With ``reload`` the module is reloaded first.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    return getattr(importlib.import_module(module_name, package=None), attr_name)
180
+
181
def get_module_from_str(module_str: str):
    """Import and return the module named by the absolute dotted path ``module_str``."""
    module = importlib.import_module(module_str, package=None)
    return module
183
+
184
def build_funcmap(module_str: str, params_names: List[str], func_prefix: str = '', func_suffix: str = '', fallback_func: Callable = None, verbose: bool = True) -> Dict:
    """Map each parameter name to the function ``<prefix><name><suffix>`` of a module.

    Args:
        module_str: dotted path of the module to search.
        params_names: names used as dictionary keys.
        func_prefix: text put before each name to form the attribute name.
        func_suffix: text put after each name to form the attribute name.
        fallback_func: used when the attribute lookup fails; defaults to
            ``empty_func``.
        verbose: print a message for every failed lookup.

    Returns:
        ``{param_name: callable}`` covering every entry of ``params_names``.
    """
    fallback = empty_func if fallback_func is None else fallback_func
    module = get_module_from_str(module_str)

    funcmap = {}
    for param in params_names:
        attr_name = f'{func_prefix}{param}{func_suffix}'
        try:
            resolved = getattr(module, attr_name)
        except Exception as e:
            if verbose:
                print(f'failed to import {attr_name} from {module_str}: {e}')
            resolved = fallback
        funcmap[param] = resolved
    return funcmap
203
+
204
def _b64encode(x: bytes) -> str:
    """Return the base64 encoding of ``x`` as a ``str``."""
    encoded = base64.b64encode(x)
    return str(encoded, "utf-8")
206
+
207
def img2b64(img):
    """
    Encode an image as a base64 PNG string.

    ``img`` may be a PIL image or a NumPy array (converted with
    ``Image.fromarray`` first).
    """
    pil_img = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    png_bytes = io.BytesIO()
    pil_img.save(png_bytes, format='PNG')
    return _b64encode(png_bytes.getvalue())
216
+
217
def save_encoded_image(b64_image: str, output_path: str):
    """Decode a base64 string and write the raw bytes to ``output_path``."""
    with open(output_path, "wb") as out_file:
        raw = base64.b64decode(b64_image)
        out_file.write(raw)
220
+
221
def submit_request(url, data, exist_on_exception=True, auth=None, wait_time = 5):
    """POST ``data`` to ``url`` and return the response.

    Args:
        url: target URL.
        data: request body passed to ``requests.post``.
        exist_on_exception: terminate the process when the request finally
            fails (parameter name kept as-is for compatibility).
        auth: forwarded to ``requests.post``.
        wait_time: seconds to sleep before retrying after a failure. With a
            positive value the request is retried indefinitely; with ``0``
            or less the first failure is final.

    Returns:
        The ``requests`` response. After a final failure this is the last
        response received, or ``None`` if there was none.
    """
    response = None
    try:
        while True:
            try:
                response = requests.post(url, data=data, auth=auth)
                response.raise_for_status()
                break
            except Exception:
                if wait_time > 0:
                    print(traceback.format_exc(), file=sys.stderr)
                    print(f'sleep {wait_time} sec...')
                    time.sleep(wait_time)
                    continue
                else:
                    raise
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        if response is not None:
            print('response content: ' + response.text)
        if exist_on_exception:
            # sys.exit() instead of the bare exit() helper: the latter is
            # injected by the ``site`` module and is missing under
            # ``python -S`` and in some frozen builds.
            sys.exit()
    return response
utils/logger.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import os
4
+ import os.path as osp
5
+ from glob import glob
6
+ import termcolor
7
+
8
+
9
+ if os.name == "nt": # Windows
10
+ import colorama
11
+ colorama.init()
12
+
13
+
14
# Colour name used for each log level (looked up by ColoredFormatter).
COLORS = dict(
    WARNING="yellow",
    INFO="white",
    DEBUG="blue",
    CRITICAL="red",
    ERROR="red",
)
21
+
22
+
23
class ColoredFormatter(logging.Formatter):
    """Formatter that adds ``levelname2``, ``message2``, ``asctime2``,
    ``module2``, ``funcName2`` and ``lineno2`` fields to each record.

    The fields are coloured with ``termcolor`` when ``use_color`` is true and
    the record's level has an entry in ``COLORS``. Otherwise they carry the
    same text without colour codes, so a format string that references them
    still works with colouring disabled or with custom log levels.
    """

    def __init__(self, fmt, use_color=True):
        logging.Formatter.__init__(self, fmt)
        self.use_color = use_color

    def format(self, record):
        """Attach the extra fields to ``record`` and format it."""
        levelname = record.levelname
        colorize = self.use_color and levelname in COLORS
        level_color = COLORS[levelname] if colorize else None

        def paint(text, color, bold=False):
            # plain text fallback keeps the extra fields defined in every case
            if not colorize:
                return str(text)
            if bold:
                return termcolor.colored(text, color=color, attrs=["bold"])
            return termcolor.colored(text, color=color)

        record.levelname2 = paint("{:<7}".format(levelname), level_color, bold=True)
        record.message2 = paint(record.getMessage(), level_color, bold=True)

        record.asctime2 = paint(datetime.datetime.fromtimestamp(record.created), "green")

        record.module2 = paint(record.module, "cyan")
        record.funcName2 = paint(record.funcName, "cyan")
        record.lineno2 = paint(record.lineno, "cyan")
        return logging.Formatter.format(self, record)
49
+
50
# Console layout; the ``*2`` fields are filled in by ColoredFormatter.format.
FORMAT = "[%(levelname2)s] %(module2)s:%(funcName2)s:%(lineno2)s - %(message2)s"


class ColoredLogger(logging.Logger):
    """Logger at INFO level that is created with a coloured console handler."""

    def __init__(self, name):
        super().__init__(name, logging.INFO)

        console = logging.StreamHandler()
        console.setFormatter(ColoredFormatter(FORMAT))
        self.addHandler(console)
66
+
67
+
68
def setup_logging(logfile_dir: str, max_num_logs=14):
    """Attach a new timestamped log file to the module ``logger``.

    ``logfile_dir`` is created if missing. Otherwise, when it already holds
    ``max_num_logs`` or more ``*.log`` files, the first ones in sorted name
    order are removed so that, including the new file, ``max_num_logs``
    remain. A removal error is logged and stops the clean-up. The file
    handler records everything from DEBUG upwards, without colour codes.
    """
    if osp.exists(logfile_dir):
        existing = sorted(glob(osp.join(logfile_dir, '*.log')))
        count = len(existing)
        if count >= max_num_logs:
            surplus = count - max_num_logs + 1  # +1 leaves room for the new file
            try:
                for idx in range(surplus):
                    os.remove(existing[idx])
            except Exception as e:
                logger.error(e)
    else:
        os.makedirs(logfile_dir)

    stamp = datetime.datetime.now().strftime('_%Y_%m_%d-%H_%M_%S.log')
    file_handler = logging.FileHandler(osp.join(logfile_dir, stamp), mode='w', encoding='utf-8')
    plain_format = logging.Formatter(
        ("[%(levelname)s] %(module)s:%(funcName)s:%(lineno)s - %(message)s")
    )
    file_handler.setFormatter(plain_format)
    file_handler.setLevel(logging.DEBUG)
    logger.addHandler(file_handler)
94
+
95
+
96
# Module-level logger setup (runs on import).
# setLoggerClass is process-wide: loggers created by logging.getLogger()
# after this point are ColoredLogger instances.
logging.setLoggerClass(ColoredLogger)
logger = logging.getLogger('BallonTranslator')
# ColoredLogger.__init__ starts at INFO; lower it to DEBUG here
logger.setLevel(logging.DEBUG)
# do not forward records to the root logger's handlers
logger.propagate = False
utils/message.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import traceback
2
+ from typing import Callable, List, Dict
3
+
4
+ from . import shared
5
+ from .logger import logger as LOGGER
6
+
7
+
8
def create_error_dialog(exception: Exception, error_msg: str = None, exception_type: str = None):
    '''
    Log an exception and popup an error dialog in main thread
    Args:
        exception: The exception being reported; str(exception) starts the message
        error_msg: Optional description text, appended after str(exception) on a new line
        exception_type: Specify it to avoid errors dialog of the same type popup repeatedly
    '''

    detail_traceback = traceback.format_exc()
    exception_type = '' if exception_type is None else exception_type

    # A non-empty type that is already recorded in shared.showed_exception is skipped.
    if exception_type != '' and exception_type in shared.showed_exception:
        return

    message = str(exception)
    if error_msg is not None:
        message = message + '\n' + error_msg
    LOGGER.error(message + '\n')
    LOGGER.error(detail_traceback)

    if shared.HEADLESS:
        return
    shared.create_errdialog_in_mainthread(message, detail_traceback, exception_type)
34
+
35
+
36
def create_info_dialog(info_msg, btn_type=None, modal: bool = False, frame_less: bool = False, signal_slot_map_list: List[Dict] = None):
    '''
    Log info_msg and popup a info dialog in main thread (skipped when shared.HEADLESS is set)
    '''
    LOGGER.info(info_msg)
    if shared.HEADLESS:
        return
    dialog_kwargs = {
        'info_msg': info_msg,
        'btn_type': btn_type,
        'modal': modal,
        'frame_less': frame_less,
        'signal_slot_map_list': signal_slot_map_list,
    }
    shared.create_infodialog_in_mainthread(dialog_kwargs)
43
+
44
+
45
def connect_once(signal, exec_func: Callable):
    '''
    signal.emit will only trigger exec_func once
    Args:
        signal: Object exposing connect(callable) and disconnect(callable)
        exec_func: Callable invoked with the emitted arguments on the first emit

    The wrapper disconnects itself after the first call whether exec_func
    returns or raises; an exception raised by exec_func propagates unchanged.
    '''

    def _disconnect_after_called(*func_args, **func_kwargs):
        try:
            exec_func(*func_args, **func_kwargs)
        finally:
            # Best effort: a failed disconnect is reported but never masks
            # the result (or the exception) of exec_func.
            try:
                signal.disconnect(_disconnect_after_called)
            except Exception:
                print('Failed to disconnect')
                print(traceback.format_exc())

    signal.connect(_disconnect_after_called)
utils/package.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copied from https://github.com/HansBug/hbutils/blob/main/hbutils/system/python/package.py
2
+ # to replace the deprecated pkg_resources
3
+
4
+ import functools
5
+ import itertools
6
+ import os
7
+ import pathlib
8
+ import subprocess
9
+ import sys
10
+ from typing import List, Optional
11
+
12
+ from packaging.requirements import Requirement
13
+ from packaging.utils import canonicalize_name
14
+
15
+ try:
16
+ import importlib.metadata as importlib_metadata
17
+ except (ModuleNotFoundError, ImportError):
18
+ import importlib_metadata
19
+ from packaging.version import Version
20
+
21
+
22
def package_version(name: str) -> Optional[Version]:
    """
    Overview:
        Look up the installed version of the package called ``name``.

    :param name: Name of the package; it is canonicalized before the lookup.
    :return: A :class:`packaging.version.Version` object, or ``None`` when the package is not installed.

    Examples::
        >>> from hbutils.system import package_version
        >>>
        >>> package_version('pip')
        <Version('21.3.1')>
        >>> package_version('setuptools')
        <Version('59.6.0')>
        >>> package_version('not_a_package')
        None
    """
    dist_name = canonicalize_name(name)
    try:
        dist = importlib_metadata.distribution(dist_name)
    except importlib_metadata.PackageNotFoundError:
        return None
    return Version(dist.version)
44
+
45
+
46
+ def _nonblank(text):
47
+ return text and not text.startswith('#')
48
+
49
+
50
@functools.singledispatch
def yield_lines(iterable):
    r"""
    Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L537 .
    Lazily yield the meaningful lines of a string or of a (nested) iterable of strings:
    each line is stripped, and empty lines or lines starting with ``#`` are dropped.
    >>> list(yield_lines(''))
    []
    >>> list(yield_lines(['foo', 'bar']))
    ['foo', 'bar']
    >>> list(yield_lines('foo\nbar'))
    ['foo', 'bar']
    >>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
    ['foo', 'baz #comment']
    >>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n']))
    ['foo', 'bar', 'baz', 'bing']
    """
    return itertools.chain.from_iterable(yield_lines(item) for item in iterable)


@yield_lines.register(str)
def _(text):
    # str overload: split into lines, strip each, skip blanks and '#' lines.
    stripped = (raw.strip() for raw in text.splitlines())
    return (line for line in stripped if line and not line.startswith('#'))
72
+
73
+
74
def drop_comment(line):
    """
    Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L560 .
    Cut the line at the first occurrence of ``' #'`` (space followed by hash).
    >>> drop_comment('foo # bar')
    'foo'
    A hash that is not preceded by a space is kept (it may be part of a URL).
    >>> drop_comment('https://example.com/foo#bar')
    'https://example.com/foo#bar'
    """
    kept, _separator, _comment = line.partition(' #')
    return kept
85
+
86
+
87
def join_continuation(lines):
    r"""
    Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L575 .
    Merge every line ending in a backslash with the line that follows it.
    >>> list(join_continuation(['foo \\', 'bar', 'baz']))
    ['foobar', 'baz']
    >>> list(join_continuation(['foo \\', 'bar \\', 'baz']))
    ['foobarbaz']
    The last two characters are removed before joining, so the character
    preceding the backslash is elided as well.
    >>> list(join_continuation(['goo\\', 'dly']))
    ['godly']
    A continued line with nothing left to continue it is dropped.
    >>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
    ['foo']
    """
    source = iter(lines)
    while True:
        try:
            current = next(source)
        except StopIteration:
            return
        while current.endswith('\\'):
            try:
                following = next(source)
            except StopIteration:
                # Input ran out while a continuation was pending: suppress it.
                return
            current = current[:-2].strip() + following
        yield current
114
+
115
+
116
def load_req_file(requirements_file: str) -> List[str]:
    """
    Overview:
        Read a ``requirements.txt`` style file and return its requirement items.

    :param requirements_file: Requirements file.
    :return requirements: List of requirements, each re-rendered via ``str(Requirement(...))``.

    Examples::
        >>> from hbutils.system import load_req_file
        >>> load_req_file('requirements.txt')
        ['packaging>=21.3', 'setuptools>=50.0']
    """
    with pathlib.Path(requirements_file).open() as reqfile:
        # Pipeline: meaningful lines -> strip trailing comments -> join continuations.
        without_comments = (drop_comment(line) for line in yield_lines(reqfile))
        logical_lines = join_continuation(without_comments)
        # Consume the lazy pipeline while the file is still open.
        return [str(Requirement(line)) for line in logical_lines]
134
+
135
+
136
def pip(*args, silent: bool = False):
    """
    Overview:
        Run pip command with code, using ``sys.executable -m pip``.

    :param args: Command line arguments for ``pip`` command.
    :param silent: Do not print anything. Default is false, which means the output goes to
        ``sys.stdout`` and ``sys.stderr``. When true, the output is captured instead.
    :raises AssertionError: When pip exits with a non-zero return code.
    :raises subprocess.CalledProcessError: Same condition, when assertions are disabled (``-O``).

    Examples::
        >>> from hbutils.system import pip
        >>> pip('-V')
        pip 22.3.1 from /home/user/myproject/venv/lib/python3.7/site-packages/pip (python 3.7)
        >>> pip('-V', silent=True)  # nothing will be printed
    """
    process = subprocess.run(
        [sys.executable, '-m', 'pip', *args],
        stdin=sys.stdin if not silent else None,
        stdout=sys.stdout if not silent else subprocess.PIPE,
        stderr=sys.stderr if not silent else subprocess.PIPE,
    )

    if process.returncode:
        def _as_text(stream):
            # stdout/stderr are None unless they were captured (silent mode);
            # decode leniently so odd console encodings cannot hide the pip error.
            if stream is None:
                return '<not captured>'
            return stream.decode(errors='replace')

        message = f'Error when calling {process.args!r}{os.linesep}' \
                  f'Error Code - {process.returncode}{os.linesep}' \
                  f'Stdout:{os.linesep}' \
                  f'{_as_text(process.stdout)}{os.linesep}' \
                  f'{os.linesep}' \
                  f'Stderr:{os.linesep}' \
                  f'{_as_text(process.stderr)}{os.linesep}'
        # Kept as an assert so callers still see AssertionError as before.
        assert not process.returncode, message
    process.check_returncode()
165
+
166
+
167
def _yield_reqs_to_install(req: Requirement, current_extra: str = ''):
    """
    Yield the requirements that still need installing for ``req``.

    ``req`` itself is yielded when it is not installed or its installed version
    does not satisfy the specifier. Otherwise the ``Requires-Dist`` entries whose
    marker matches one of ``req.extras`` are checked recursively. Nothing is
    yielded when ``req``'s own marker evaluates to false for ``current_extra``.
    """
    marker = req.marker
    if marker and not marker.evaluate({'extra': current_extra}):
        return

    try:
        installed = importlib_metadata.distribution(req.name).version
    except importlib_metadata.PackageNotFoundError:  # req not installed
        yield req
        return

    if not req.specifier.contains(installed, prereleases=True):  # main version not match
        yield req
        return

    declared = importlib_metadata.metadata(req.name).get_all('Requires-Dist') or []
    for raw_child in declared:
        child = Requirement(raw_child)
        # First requested extra that activates this child requirement, if any.
        matched_extra = next(
            (extra for extra in req.extras
             if child.marker and child.marker.evaluate({'extra': extra})),
            None,
        )
        if matched_extra is not None:  # check for extra reqs
            yield from _yield_reqs_to_install(child, matched_extra)
192
+
193
+
194
def _check_req(req: Requirement):
    """Return ``True`` when ``_yield_reqs_to_install`` yields nothing for ``req``."""
    for _pending in _yield_reqs_to_install(req):
        # The first pending requirement is enough to answer.
        return False
    return True
196
+
197
+
198
def check_reqs(reqs: List[str]) -> bool:
    """
    Overview:
        Tell whether every requirement in ``reqs`` is satisfied.

    :param reqs: List of requirements.
    :return satisfied: All the requirements in ``reqs`` satisfied or not.

    Examples::
        >>> from hbutils.system import check_reqs
        >>> check_reqs(['pip>=20.0'])
        True
        >>> check_reqs(['pip~=19.2'])
        False
        >>> check_reqs(['pip>=20.0', 'setuptools>=50.0'])
        True

    .. note::
        If a requirement's marker is not satisfied in this environment,
        **it will be ignored** instead of return ``False``.
    """
    for item in reqs:
        if not _check_req(Requirement(item)):
            return False
    return True
220
+
221
+
222
def check_req_file(requirements_file: str) -> bool:
    """
    Overview:
        Tell whether every requirement listed in ``requirements_file`` is satisfied.

    :param requirements_file: Requirements file, such as ``requirements.txt``.
    :return satisfied: All the requirements in ``requirements_file`` satisfied or not.

    Examples::
        >>> from hbutils.system import check_req_file
        >>>
        >>> check_req_file('requirements.txt')
        True
        >>> check_req_file('requirements-test.txt')
        True
    """
    requirement_items = load_req_file(requirements_file)
    return check_reqs(requirement_items)
239
+
240
+
241
def pip_install(reqs: List[str], silent: bool = False, force: bool = False, user: bool = False):
    """
    Overview:
        Pip install requirements with code.
        Similar to ``pip install req1 req2 ...``.

    :param reqs: Requirement items to install.
    :param silent: Do not print anything. Default is ``False``.
    :param force: Force execute the ``pip install`` command. Default is ``False``, which means
        the requirements are checked first and pip only runs when some are not satisfied.
    :param user: User mode, represents ``--user`` option in ``pip``.

    Examples::
        >>> from hbutils.system import pip_install
        >>> pip_install(['scikit-learn'])  # not installed
        Looking in indexes: https://xxx/simple
        Collecting scikit-learn
          Using cached https://xxx/scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (24.8 MB)
        Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
        Successfully installed joblib-1.2.0 scikit-learn-1.0.2 scipy-1.7.3 threadpoolctl-3.1.0
        >>> pip_install(['numpy>=1.10.0'])  # installed
        >>> pip_install(['numpy>=1.10.0'], force=True)  # force execute
        Looking in indexes: https://xxx/simple
        Requirement already satisfied: numpy>=1.10.0 in ./venv/lib/python3.7/site-packages (1.21.6)
    """
    if not force and check_reqs(reqs):
        return
    user_flag = ['--user'] if user else []
    pip('install', *user_flag, *reqs, silent=silent)
269
+
270
+
271
def pip_install_req_file(requirements_file: str, silent: bool = False, force: bool = False, user: bool = False):
    """
    Overview:
        Pip install requirements from file with code.
        Similar to ``pip install -r requirements.txt``.

    :param requirements_file: Requirements file, such as ``requirements.txt``.
    :param silent: Do not print anything. Default is ``False``.
    :param force: Force execute the ``pip install`` command. Default is ``False``, which means
        the requirements are checked first and pip only runs when some are not satisfied.
    :param user: User mode, represents ``--user`` option in ``pip``.

    Examples::
        >>> from hbutils.system import pip_install_req_file
        >>> pip_install_req_file('requirements.txt')  # pip install -r requirements.txt
    """
    if not force and check_req_file(requirements_file):
        return
    user_flag = ['--user'] if user else []
    pip('install', *user_flag, '-r', requirements_file, silent=silent)
utils/proj_imgtrans.py ADDED
@@ -0,0 +1,720 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, json, shutil, re, docx, docx2txt, piexif, cv2
2
+ from docx.shared import Inches
3
+ from docx import Document
4
+ import piexif.helper
5
+ import numpy as np
6
+ import os.path as osp
7
+ from typing import Tuple, Union, List, Dict
8
+ from PIL import Image
9
+
10
+ from utils.watermark_utils import apply_watermark_to_pil_image
11
+ from .logger import logger as LOGGER
12
+ from .io_utils import find_all_imgs, imread, imwrite, NumpyEncoder
13
+ from .textblock import TextBlock, FontFormat
14
+ from .config import pcfg
15
+ from . import shared
16
+ from .exceptions import ImgnameNotInProjectException, ProjectLoadFailureException, ProjectDirNotExistException, ProjectNotSupportedException
17
+
18
class ImageLoadException(Exception):
    '''
    Raised to report an image that could not be loaded; keeps the path in img_path
    '''

    def __init__(self, img_path):
        self.img_path = img_path
        Exception.__init__(self, f"Failed to load image: {img_path}")
22
+
23
+
24
def get_last_modified_file(file_prefix, exts, ext_fallback=None):
    '''
    get last modified file from files sharing same prefix
    Args:
        file_prefix: Path without extension
        exts: Candidate extensions, each appended to file_prefix
        ext_fallback: Extension used when no candidate exists; exts[0] is used if this is None
    Returns:
        Path of the existing candidate with the greatest mtime, otherwise the fallback path
    '''
    newest_path, newest_mtime = None, -1
    for suffix in exts:
        candidate = file_prefix + suffix
        if not osp.exists(candidate):
            continue
        mtime = osp.getmtime(candidate)
        if mtime > newest_mtime:
            newest_path, newest_mtime = candidate, mtime

    if newest_path is not None:
        return newest_path
    return file_prefix + (exts[0] if ext_fallback is None else ext_fallback)
41
+
42
+
43
def write_jpg_metadata(imgpath: str, metadata="a metadata"):
    '''
    Store metadata as the Exif UserComment of the image at imgpath (inserted in place via piexif)
    '''
    comment_bytes = piexif.helper.UserComment.dump(metadata, encoding='unicode')
    exif_bytes = piexif.dump({"Exif": {piexif.ExifIFD.UserComment: comment_bytes}})
    piexif.insert(exif_bytes, imgpath)
47
+
48
def read_jpg_metadata(imgpath: str):
    '''
    Read the Exif UserComment of the image at imgpath and return it parsed as JSON
    '''
    raw_comment = piexif.load(imgpath)["Exif"][piexif.ExifIFD.UserComment]
    return json.loads(piexif.helper.UserComment.load(raw_comment))
53
+
54
# A page starts at a line beginning with '###' followed by whitespace.
page_start_pattern = re.compile(r'^###\s+', re.MULTILINE)
# A text block starts at a line beginning with digits followed by a dot.
text_blkid_start_pattern = re.compile(r'^\d+\.', re.MULTILINE)


def parse_txt_translation(file_path: str):
    '''
    Parse a dumped text file into a list of page dicts
    Each dict has the keys:
        page_content: Raw text of the page, header line included
        page_name: First line of the page with the leading '### ' removed, stripped
        blk_list: Stripped text following each numbered block marker, in order
    Text before the first page header is ignored.
    '''
    with open(file_path, 'r', encoding='utf8') as f:
        content = f.read()

    # Every page spans from its header to the next header (or the end of the file).
    page_offsets = [m.start() for m in page_start_pattern.finditer(content)]
    page_ends = page_offsets[1:] + [len(content)]
    page_list = [{'page_content': content[lo:hi]} for lo, hi in zip(page_offsets, page_ends)]

    for page_dict in page_list:
        page_content = page_dict['page_content']
        header_line = page_content.split('\n')[0]
        page_dict['page_name'] = page_start_pattern.sub('', header_line).strip()

        # A block's text runs from the end of its marker to the start of the next marker.
        id_spans = [m.span() for m in text_blkid_start_pattern.finditer(page_content)]
        text_ends = [span[0] for span in id_spans[1:]] + [len(page_content)]
        page_dict['blk_list'] = [
            page_content[span[1]:stop].strip() for span, stop in zip(id_spans, text_ends)
        ]

    return page_list
86
+
87
+
88
class TextBlkEncoder(NumpyEncoder):
    '''
    JSON encoder: TextBlock -> obj.to_dict(), FontFormat -> vars(obj), anything else -> NumpyEncoder
    '''

    def default(self, obj):
        if isinstance(obj, TextBlock):
            return obj.to_dict()
        if isinstance(obj, FontFormat):
            return vars(obj)
        return super().default(obj)
95
+
96
+
97
+ class ProjImgTrans:
98
+
99
+ def __init__(self, directory: str = None):
100
+ self.type = 'imgtrans'
101
+ self.directory: str = None
102
+ self.pages: Dict[str, List[TextBlock]] = {}
103
+ self._pagename2idx = {}
104
+ self._idx2pagename = {}
105
+
106
+ self._fuzzy_inpainted_list = None
107
+
108
+ self.not_found_pages: Dict[str, List[TextBlock]] = {}
109
+ self.new_pages: List[str] = []
110
+ self.proj_path: str = None
111
+
112
+ self.current_img: str = None
113
+ self.img_array: np.ndarray = None
114
+ self.mask_array: np.ndarray = None
115
+ self.inpainted_array: np.ndarray = None
116
+
117
+ # Watermark settings
118
+ self.enable_watermark = False
119
+ self.watermark_path = ""
120
+ self.watermark_opacity = 0.7
121
+
122
+ if directory is not None:
123
+ self.load(directory)
124
+
125
def idx2pagename(self, idx: int) -> str:
    '''
    Return the page name registered under index idx (KeyError if the index is unknown)
    '''
    pagename = self._idx2pagename[idx]
    return pagename
127
+
128
def pagename2idx(self, pagename: str) -> int:
    '''
    Return the index of pagename, or -1 when pagename is not a page of this project
    '''
    if pagename not in self.pages:
        return -1
    return self._pagename2idx[pagename]
132
+
133
def proj_name(self) -> str:
    '''
    Return the project name: self.type and the base name of self.directory joined by '_'
    '''
    return '_'.join((self.type, osp.basename(self.directory)))
135
+
136
def load(self, directory: str, json_path: str = None) -> bool:
    '''
    Load the project stored in directory, creating a new project if no project file exists
    Args:
        directory: Project directory
        json_path: Project file to read; defaults to <directory>/<proj_name()>.json
    Returns:
        True if the project file did not exist and a new project was created
    Raises:
        ProjectLoadFailureException: The project file could not be read or parsed
    '''
    self.directory = directory
    if json_path is None:
        self.proj_path = osp.join(self.directory, self.proj_name() + '.json')
    else:
        self.proj_path = json_path

    new_proj = False
    proj_dict = None  # stays None when a new project is created
    if not osp.exists(self.proj_path):
        new_proj = True
        self.new_project()
    else:
        try:
            with open(self.proj_path, 'r', encoding='utf8') as f:
                proj_dict = json.loads(f.read())
        except Exception as e:
            raise ProjectLoadFailureException(e) from e
        self.load_from_dict(proj_dict)

    if not osp.exists(self.inpainted_dir()):
        os.makedirs(self.inpainted_dir())
    if not osp.exists(self.mask_dir()):
        os.makedirs(self.mask_dir())

    # Restore the watermark settings saved by to_dict(); keys missing from
    # older project files leave the current values untouched.
    if proj_dict is not None:
        for key in ('enable_watermark', 'watermark_path', 'watermark_opacity'):
            if key in proj_dict:
                setattr(self, key, proj_dict[key])

    return new_proj
167
+
168
def mask_dir(self):
    '''
    Return the path of the 'mask' sub-directory of the project directory
    '''
    subdir = 'mask'
    return osp.join(self.directory, subdir)
170
+
171
def inpainted_dir(self):
    '''
    Return the path of the 'inpainted' sub-directory of the project directory
    '''
    subdir = 'inpainted'
    return osp.join(self.directory, subdir)
173
+
174
def result_dir(self):
    '''
    Return the path of the 'result' sub-directory of the project directory
    '''
    subdir = 'result'
    return osp.join(self.directory, subdir)
176
+
177
def load_from_dict(self, proj_dict: dict):
    '''
    Rebuild the page tables from a project dict and select the current image
    Args:
        proj_dict: Parsed project file; must contain 'pages', may contain 'current_img'
    Raises:
        ProjectNotSupportedException: The page data could not be turned into text blocks

    Images found in the project directory become pages: those listed in
    proj_dict['pages'] get their saved blocks, the others start empty and are
    appended to self.new_pages. Saved pages without an image on disk are kept
    in self.not_found_pages. If the saved current image is missing or cannot
    be set, the first page is selected instead.
    '''
    self.set_current_img(None)
    try:
        self.pages = {}
        self._pagename2idx = {}
        self._idx2pagename = {}
        self.not_found_pages = {}
        page_dict = proj_dict['pages']
        not_found_pages = list(page_dict.keys())
        found_pages = find_all_imgs(img_dir=self.directory, abs_path=False, sort=True)
        for ii, imname in enumerate(found_pages):
            if imname in page_dict:
                self.pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
                not_found_pages.remove(imname)
            else:
                self.pages[imname] = []
                self.new_pages.append(imname)
            self._pagename2idx[imname] = ii
            self._idx2pagename[ii] = imname
        for imname in not_found_pages:
            self.not_found_pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
    except Exception as e:
        raise ProjectNotSupportedException(e) from e

    set_img_failed = False
    if 'current_img' in proj_dict:
        current_img = proj_dict['current_img']
        try:
            self.set_current_img(current_img)
        except (ImgnameNotInProjectException, RuntimeError) as e:
            LOGGER.error(f"Failed to set current image {current_img}: {e}")
            set_img_failed = True
    else:
        set_img_failed = True
        # No local image name exists in this branch, so report the missing key itself.
        LOGGER.warning("'current_img' not found in project file.")

    if set_img_failed:
        if len(self.pages) > 0:
            try:
                self.set_current_img_byidx(0)
            except RuntimeError as e:
                LOGGER.error(f"Failed to set current image by index 0: {e}")
217
+
218
def load_translation_from_txt(self, file_path: str):
    '''
    Fill block translations from a text file parsed by parse_txt_translation
    Returns:
        (all_matched, report) where report maps 'missing_pages', 'unmatched_pages',
        'unexpected_pages' (lists) and 'matched_pages' (set) to page names

    Blocks are paired with source texts by position; when the counts differ the
    page is reported as unmatched and only the common prefix is updated.
    Every updated block also gets its rich_text cleared.
    '''
    missing_pages, unmatched_pages, unexpected_pages, matched_names = [], [], [], []

    for page_dict in parse_txt_translation(file_path):
        page_name = page_dict['page_name']
        if page_name not in self.pages:
            unexpected_pages.append(page_name)
            continue
        matched_names.append(page_name)

        blklist = self.pages[page_name]
        src_blk_list = page_dict['blk_list']
        n_blk, n_src_blk = len(blklist), len(src_blk_list)
        if n_src_blk != n_blk:
            LOGGER.warning(f'Unmatched text blocks in {page_name}, number of text blocks in this page vs source file: {n_blk}-{n_src_blk}')
            unmatched_pages.append(page_name)
        # zip stops at the shorter list, i.e. min(n_blk, n_src_blk) blocks
        for blk, translated in zip(blklist, src_blk_list):
            blk.rich_text = ''
            blk.translation = translated

    matched_pages = set(matched_names)
    if len(matched_pages) != self.num_pages:
        missing_pages.extend(name for name in self.pages if name not in matched_pages)

    all_matched = not (missing_pages or unmatched_pages or unexpected_pages)
    report = {
        'missing_pages': missing_pages,
        'unmatched_pages': unmatched_pages,
        'unexpected_pages': unexpected_pages,
        'matched_pages': matched_pages,
    }
    return all_matched, report
251
+
252
def load_from_json(self, json_path: str):
    '''
    Load the project described by json_path, using its parent folder as project directory
    Raises:
        ProjectLoadFailureException: Loading failed; the previously loaded
            directory (if there was one) is reloaded before raising
    '''
    old_dir = self.directory
    directory = osp.dirname(json_path)
    try:
        self.load(directory, json_path=json_path)
    except Exception as e:
        # Roll back only when there is something to roll back to, and never
        # let a failing rollback replace the error that caused it.
        if old_dir is not None:
            try:
                self.load(old_dir)
            except Exception as restore_err:
                LOGGER.error(f'Failed to restore previous project {old_dir}: {restore_err}')
        raise ProjectLoadFailureException(e) from e
260
+
261
def set_current_img(self, imgname: str):
    '''
    Select imgname as current image and load its image, mask and inpainted arrays
    Passing None clears the current image and all three arrays.
    Raises:
        ImgnameNotInProjectException: imgname is not a page of this project
        RuntimeError: The image could not be read
    '''
    if imgname is None:
        self.current_img = None
        self.img_array = None
        self.mask_array = None
        self.inpainted_array = None
        return

    if imgname not in self.pages:
        raise ImgnameNotInProjectException

    self.current_img = imgname
    img_path = self.current_img_path()
    mask_path = self.get_mask_path(get_last_modified=True)

    self.img_array = imread(img_path)
    if self.img_array is None:
        raise RuntimeError(f"Failed to load image: {img_path}")
    im_h, im_w = self.img_array.shape[:2]

    # Use the saved mask when there is one, otherwise an all-zero mask of the image size.
    if osp.exists(mask_path):
        self.mask_array = imread(mask_path, cv2.IMREAD_GRAYSCALE)
    else:
        self.mask_array = np.zeros((im_h, im_w), dtype=np.uint8)

    # Without a saved inpainted image, start from a copy of the source image.
    inpainted = self.load_inpainted_by_imgname(imgname)
    self.inpainted_array = np.copy(self.img_array) if inpainted is None else inpainted
284
+
285
def set_current_img_byidx(self, idx: int):
    '''
    Select the page at position idx as current image
    A negative idx counts from the end; an out-of-range idx clears the current image.
    '''
    total = self.num_pages
    if idx < 0:
        idx += total
    target = self.idx2pagename(idx) if 0 <= idx < total else None
    self.set_current_img(target)
293
+
294
def get_blklist_byidx(self, idx: int) -> List[TextBlock]:
    '''
    Return the text block list of the page at position idx
    '''
    pagename = self.idx2pagename(idx)
    return self.pages[pagename]
296
+
297
@property
def num_pages(self) -> int:
    '''
    Number of pages in self.pages
    '''
    page_count = len(self.pages)
    return page_count
300
+
301
@property
def current_idx(self) -> int:
    '''
    Index of the current image as given by pagename2idx(self.current_img)
    '''
    current = self.current_img
    return self.pagename2idx(current)
304
+
305
def new_project(self):
    '''
    Start a fresh project from the images found in self.directory and save it
    Every found image becomes a page with an empty block list; the first page
    is selected as current image.
    Raises:
        ProjectDirNotExistException: self.directory does not exist
    '''
    if not osp.exists(self.directory):
        raise ProjectDirNotExistException
    self.set_current_img(None)

    imglist = list(find_all_imgs(self.directory, abs_path=False, sort=True))
    self.pages = {imgname: [] for imgname in imglist}
    self._pagename2idx = {imgname: ii for ii, imgname in enumerate(imglist)}
    self._idx2pagename = dict(enumerate(imglist))

    self.set_current_img_byidx(0)
    self.save()
319
+
320
def save(self):
    '''
    Write the project (self.to_dict() as JSON) to self.proj_path
    Raises:
        ProjectDirNotExistException: self.directory does not exist
    '''
    if not osp.exists(self.directory):
        raise ProjectDirNotExistException
    # Serialize before opening the file: opening with "w" truncates it, so a
    # serialization error would otherwise wipe the existing project file.
    serialized = json.dumps(self.to_dict(), ensure_ascii=False, cls=TextBlkEncoder)
    with open(self.proj_path, "w", encoding="utf-8") as f:
        f.write(serialized)
    LOGGER.debug(f'project saved to {self.proj_path}')
326
+
327
def to_dict(self) -> Dict:
    '''
    Return the project state as a dict
    'pages' merges self.pages with self.not_found_pages (the latter wins on equal keys).
    '''
    merged_pages = {**self.pages, **self.not_found_pages}
    state = {'directory': self.directory, 'pages': merged_pages, 'current_img': self.current_img}
    # watermark settings
    state['enable_watermark'] = self.enable_watermark
    state['watermark_path'] = self.watermark_path
    state['watermark_opacity'] = self.watermark_opacity
    return state
338
+
339
def read_img(self, imgname: str) -> np.ndarray:
    '''
    Read the image of page imgname from the project directory
    Raises:
        ImgnameNotInProjectException: imgname is not a page of this project
    '''
    if imgname in self.pages:
        return imread(osp.join(self.directory, imgname))
    raise ImgnameNotInProjectException
343
+
344
def save_mask(self, img_name, mask: np.ndarray):
    '''
    Write mask to the mask path of img_name using the configured intermediate image extension
    '''
    target_path = self.get_mask_path(img_name)
    imwrite(target_path, mask, ext=pcfg.intermediate_imgsave_ext)
346
+
347
def save_inpainted(self, img_name, inpainted: np.ndarray):
    '''
    Write inpainted to the inpainted path of img_name using the configured intermediate image extension
    '''
    target_path = self.get_inpainted_path(img_name)
    imwrite(target_path, inpainted, ext=pcfg.intermediate_imgsave_ext)
349
+
350
def current_img_path(self) -> str:
    '''
    Return the path of the current image inside the project directory, or None if no image is selected
    '''
    current = self.current_img
    return None if current is None else osp.join(self.directory, current)
354
+
355
def get_mask_path(self, imgname: str = None, get_last_modified=False) -> str:
    '''
    Return the mask file path for imgname (defaults to the current image)
    Args:
        imgname: Page name; its extension is replaced
        get_last_modified: If set, choose between the '.jxl' and '.png' variants via
            get_last_modified_file, falling back to the configured intermediate extension
    '''
    target = self.current_img if imgname is None else imgname
    fileprefix = osp.join(self.mask_dir(), osp.splitext(target)[0])
    if not get_last_modified:
        return fileprefix + pcfg.intermediate_imgsave_ext
    return get_last_modified_file(fileprefix, ['.jxl', '.png'], ext_fallback=pcfg.intermediate_imgsave_ext)
366
+
367
def load_mask_by_imgname(self, imgname: str) -> np.ndarray:
    '''
    Read the saved mask of imgname as grayscale; returns None when no mask file exists
    '''
    mask_path = self.get_mask_path(imgname, get_last_modified=True)
    if not osp.exists(mask_path):
        return None
    return imread(mask_path, cv2.IMREAD_GRAYSCALE)
373
+
374
def get_inpainted_path(self, imgname: str = None, get_last_modified=False) -> str:
    '''
    Return the inpainted file path for imgname (defaults to the current image)
    Args:
        imgname: Page name; its extension is replaced
        get_last_modified: If set, choose between the '.jxl' and '.png' variants via
            get_last_modified_file, falling back to the configured intermediate extension

    When that path does not exist and shared.FUZZY_MATCH_IMAGE_NAME is set, the
    image at the same position in the (cached) listing of the inpainted
    directory is returned instead, provided the page index is valid.
    '''
    if imgname is None:
        imgname = self.current_img

    fileprefix = osp.join(self.inpainted_dir(), osp.splitext(imgname)[0])
    if get_last_modified:
        p = get_last_modified_file(fileprefix, ['.jxl', '.png'], ext_fallback=pcfg.intermediate_imgsave_ext)
    else:
        p = fileprefix+pcfg.intermediate_imgsave_ext

    if not osp.exists(p) and shared.FUZZY_MATCH_IMAGE_NAME:
        if self._fuzzy_inpainted_list is None:
            if osp.exists(self.inpainted_dir()):
                self._fuzzy_inpainted_list = find_all_imgs(self.inpainted_dir(), sort=True)
            else:
                self._fuzzy_inpainted_list = []
        pidx = self.pagename2idx(imgname)
        # pagename2idx returns -1 for unknown pages; without the lower bound
        # that would silently pick the last inpainted image.
        if 0 <= pidx < len(self._fuzzy_inpainted_list):
            return osp.join(self.inpainted_dir(), self._fuzzy_inpainted_list[pidx])
    return p
394
+
395
def load_inpainted_by_imgname(self, imgname: str, scale_to_src: bool = True) -> np.ndarray:
    '''
    Read the saved inpainted image of imgname
    Args:
        imgname: Page name
        scale_to_src: If set, resize the result to the size of the source image when they differ
    Returns:
        The inpainted image array, or None if no file exists or it could not be read
    '''
    mp = self.get_inpainted_path(imgname, get_last_modified=True)
    if mp is None or not osp.exists(mp):
        return None

    inpainted = imread(mp)
    if inpainted is None or not scale_to_src:
        return inpainted

    if imgname == self.current_img and self.img_array is not None:
        # the source image is already in memory
        h, w = self.img_array.shape[:2]
    else:
        # only the size is needed; the context manager closes the file handle
        with Image.open(osp.join(self.directory, imgname)) as src_img:
            h, w = src_img.height, src_img.width

    ih, iw = inpainted.shape[:2]
    if ih != h or iw != w:
        inpainted = Image.fromarray(inpainted).resize((w, h), resample=Image.Resampling.LANCZOS)
        inpainted = np.array(inpainted)
    return inpainted
410
+
411
def get_result_path(self, imgname: str) -> str:
    '''
    Return the result file path for imgname inside the result directory
    The extension comes from pcfg.imgsave_ext when it is one of
    '.jpg', '.png', '.webp', '.jxl'; otherwise '.png' is used and a warning is logged.
    '''
    ext = '.png'
    if pcfg is not None:
        if pcfg.imgsave_ext in {'.jpg', '.png', '.webp', '.jxl'}:
            ext = pcfg.imgsave_ext
        else:
            LOGGER.warning('invalid image saving ext in config.json')
    stem = osp.splitext(imgname)[0]
    return osp.join(self.result_dir(), stem + ext)
419
+
420
+
421
def backup(self):
    '''
    Not implemented; always raises NotImplementedError
    '''
    raise NotImplementedError
423
+
424
@property
def is_empty(self):
    '''
    True when the project has no pages
    '''
    page_count = len(self.pages)
    return page_count == 0
427
+
428
@property
def is_all_pages_no_text(self):
    '''
    True when every page has an empty block list (also True for a project without pages)
    '''
    block_counts = [len(blklist) for blklist in self.pages.values()]
    return all(count == 0 for count in block_counts)
431
+
432
@property
def img_valid(self):
    '''
    True when an image array is loaded (self.img_array is not None)
    '''
    loaded = self.img_array is not None
    return loaded
435
+
436
@property
def mask_valid(self):
    '''
    True when a mask array is loaded (self.mask_array is not None)
    '''
    loaded = self.mask_array is not None
    return loaded
439
+
440
@property
def inpainted_valid(self):
    '''
    True when an inpainted array is loaded (self.inpainted_array is not None)
    '''
    loaded = self.inpainted_array is not None
    return loaded
443
+
444
def set_next_img(self):
    '''
    Select the page after the current one, wrapping around to the first page; no-op without a current image
    '''
    if self.current_img is None:
        return
    next_idx = (self.current_idx + 1) % self.num_pages
    next_name = self.idx2pagename(next_idx)
    self.set_current_img(next_name)
448
+
449
def set_prev_img(self):
    '''
    Select the page before the current one, wrapping around to the last page; no-op without a current image
    '''
    if self.current_img is None:
        return
    # Python's % is non-negative for a positive modulus, so this wraps 0 -> last index.
    prev_idx = (self.current_idx - 1) % self.num_pages
    prev_name = self.idx2pagename(prev_idx)
    self.set_current_img(prev_name)
453
+
454
def current_block_list(self) -> List[TextBlock]:
    '''
    Return the text block list of the current image, or None if no image is selected
    '''
    if self.current_img is None:
        return None
    assert self.current_img in self.pages
    return self.pages[self.current_img]
460
+
461
def doc_path(self) -> str:
    '''
    Return the path of the project's .docx file: <directory>/<proj_name()>.docx
    '''
    filename = self.proj_name() + ".docx"
    return os.path.join(self.directory, filename)
463
+
464
def doc_exist(self) -> bool:
    '''
    Tell whether the file at doc_path() exists
    '''
    target = self.doc_path()
    return osp.exists(target)
466
+
467
def dump_doc(self, delete_tmp_folder=True, fin_page_signal=None):
    '''
    Export all pages to the project's .docx file
    For each page a paragraph with the page name and a two-column table are
    added: column 0 holds the cut image of a text block (with the block's data
    written into the cut's JPEG metadata), column 1 its translation.
    Args:
        delete_tmp_folder: Remove the temporary 'bubcuts' folder when done
        fin_page_signal: Optional object whose emit() is called after each page
    '''

    # start from an empty temporary folder for the cut images
    cuts_dir = os.path.join(self.directory, "bubcuts")
    if os.path.exists(cuts_dir):
        shutil.rmtree(cuts_dir)
    os.mkdir(cuts_dir)

    target_font = 'Arial'
    document = Document()
    document.styles['Normal'].font.name = target_font

    for pagename, blklist in self.pages.items():
        imgpath = os.path.join(self.directory, pagename)
        cuts_path_list, cut_width_list = gen_ballon_cuts(cuts_dir, imgpath, blklist)

        heading = document.add_paragraph(pagename)
        heading.style = document.styles['Normal']
        table = document.add_table(rows=len(cuts_path_list), cols=2, style='Table Grid')

        for row, (cut_path, width) in enumerate(zip(cuts_path_list, cut_width_list)):
            run = table.cell(row, 0).paragraphs[0].add_run()
            run.style.font.name = target_font

            blk: TextBlock = blklist[row]
            bubdict = {
                **vars(blk),
                "imgkey": pagename,
                "rich_text": '',
                "text": blk.get_text(),
            }
            metadata = json.dumps(bubdict, ensure_ascii=False, cls=TextBlkEncoder)
            write_jpg_metadata(cut_path, metadata=metadata)

            run.add_picture(cut_path, width=Inches(width/96 * 0.85))
            table.cell(row, 1).text = bubdict["translation"]

        document.add_page_break()
        if fin_page_signal is not None:
            fin_page_signal.emit()

    document.save(self.doc_path())
    if delete_tmp_folder:
        shutil.rmtree(cuts_dir)
509
+
510
def dump_txt_path(self, dump_target, suffix):
    """Build ``<directory>/<proj_name>_<dump_target><suffix>`` for a text dump."""
    base_dir = self.directory
    filename = self.proj_name() + f'_{dump_target}{suffix}'
    return osp.join(base_dir, filename)
513
+
514
def dump_txt(self, dump_target: str, suffix='.txt'):
    """Write every page's source text or translation to one text file.

    Each page becomes a ``### <page name>`` heading followed by its blocks as
    a numbered list; entries are separated by a blank line and pages by two
    blank lines. The file is written to ``self.dump_txt_path(...)``.

    Args:
        dump_target: ``'source'`` or ``'translation'``.
        suffix: ``'.txt'`` or ``'.md'``.
    """
    save_path = self.dump_txt_path(dump_target, suffix=suffix)
    assert dump_target in {'source', 'translation'}
    assert suffix in {'.txt', '.md'}

    page_chunks = []
    for page_name, blk_list in self.pages.items():
        entries = ['### ' + page_name]
        for number, blk in enumerate(blk_list, start=1):
            if dump_target == 'translation':
                text = blk.translation.strip()
            elif dump_target == 'source':
                text = blk.get_text().strip()
            entries.append(f'{number}. {text}')
        page_chunks.append('\n\n'.join(entries))

    with open(save_path, 'w', encoding='utf8') as f:
        f.write('\n\n\n'.join(page_chunks))
530
+
531
def load_doc(self, doc_path, delete_tmp_folder=True, fin_page_signal=None):
    """Rebuild text blocks from a .docx and merge them into the project.

    Images are extracted with ``docx2txt.process`` into a temporary
    ``img_folder``; the document body XML is then scanned table by table and
    row by row. Every row containing a picture yields one ``TextBlock`` built
    from the metadata read from the matching image, with its translation
    replaced by the text found in that row.

    Args:
        doc_path: path of the .docx to import.
        delete_tmp_folder: when true, remove the temporary image folder at the end.
        fin_page_signal: optional object with an ``emit()`` method, called
            once after each table has been processed.
    """
    tmp_bubble_folder = osp.join(self.directory, 'img_folder')
    os.makedirs(tmp_bubble_folder, exist_ok=True)
    docx2txt.process(doc_path, tmp_bubble_folder)

    doc = docx.Document(doc_path)
    body_xml_str = doc._body._element.xml

    pages = {}
    # Running picture counter across the whole document.
    # NOTE(review): assumes extracted files are named image<N>.jpg/.jpeg in
    # the same order pictures appear in the table rows - confirm against docx2txt.
    bub_index = 0
    for tbl in re.findall(r'<w:tbl>(.*?)</w:tbl>', body_xml_str, re.DOTALL):
        for tr in re.findall(r'<w:tr(.*?)>(.*?)</w:tr>', tbl, re.DOTALL):
            # tr is (attributes, inner xml); only rows holding a picture count.
            if re.findall(r'<pic:cNvPr id=\"(.*?)\" name=\"(.*?)\"(.*?)>', tr[1]):
                bub_index += 1
                # Concatenate the text runs of each paragraph, one line per paragraph.
                translation = ""
                for paragraph in re.findall(r'<w:p(.*?)>(.*?)</w:p>', tr[1], re.DOTALL):
                    for wt in re.findall(r'<w:t>(.*?)</w:t>', paragraph[1], re.DOTALL):
                        translation += wt
                    translation += "\n"
                # Drop the trailing newline, and a leading one if present.
                translation = translation[:-1]
                if len(translation) != 0 and translation[0] == "\n":
                    translation = translation[1:]


                bubpath = os.path.join(tmp_bubble_folder, "image"+str(bub_index))
                if osp.exists(bubpath+'.jpg'):
                    bubpath = bubpath + '.jpg'
                else:
                    bubpath = bubpath + '.jpeg'

                meta_dict = read_jpg_metadata(bubpath)
                meta_dict["translation"] = translation
                # "imgkey" names the page the block belongs to; the rest are
                # passed to TextBlock as keyword arguments.
                imgkey = meta_dict.pop("imgkey")
                if not imgkey in pages:
                    pages[imgkey] = []
                pages[imgkey].append(TextBlock(**meta_dict))

        if fin_page_signal is not None:
            fin_page_signal.emit()

    self.merge_from_proj_dict(pages)
    if delete_tmp_folder:
        shutil.rmtree(tmp_bubble_folder)
574
+
575
def merge_from_proj_dict(self, tgt_dict: Dict) -> Dict:
    """Merge ``tgt_dict`` into ``self.pages`` and rebuild the page index maps.

    Pages present in ``tgt_dict`` replace the existing entry; pages only in
    ``self.pages`` are kept. The result is ordered by sorted page name,
    written back into the same ``self.pages`` dict object, and
    ``self._pagename2idx`` / ``self._idx2pagename`` are regenerated.
    Nothing is returned.
    """
    if self.pages is None:
        self.pages = {}
    existing = self.pages

    merged = {}
    name_to_idx = {}
    idx_to_name = {}
    for idx, name in enumerate(sorted(set(existing) | set(tgt_dict))):
        merged[name] = tgt_dict[name] if name in tgt_dict else existing[name]
        name_to_idx[name] = idx
        idx_to_name[idx] = name

    # Refill in place so other references to the pages dict stay valid.
    self.pages.clear()
    self.pages.update(merged)
    self._pagename2idx = name_to_idx
    self._idx2pagename = idx_to_name
597
+
598
+
599
def gen_ballon_cuts(cuts_dir: str, imgpath: str, blk_list: List[TextBlock], resize=True) -> Tuple[List[str], List[int]]:
    """Crop every text block out of a page image and save each crop as a JPEG.

    Args:
        cuts_dir: folder the crops are written to.
        imgpath: path of the page image.
        blk_list: blocks whose ``bounding_rect()`` gives ``x, y, w, h``.
        resize: when true, crops whose scale factor is below 1 are downscaled.

    Returns:
        Two parallel lists: the crop file paths and the width recorded for each.
    """
    img = imread(imgpath)
    imgname = os.path.basename(imgpath)
    max_width = 448
    emptyw = 60
    cuts_path_list, cut_width_list = [], []

    for ii, blk in enumerate(blk_list):
        x, y, w, h = blk.bounding_rect()
        x = max(x, 0)
        y = max(y, 0)
        w = max(w, 1)
        h = max(h, 1)
        left, top, right, bottom = int(x), int(y), int(x+w), int(y+h)

        cut_path = os.path.join(cuts_dir, f'{imgname}-{ii}.jpg')
        bub = img[top:bottom, left:right]

        if bub.shape[0] < 1 or bub.shape[1] < 1:
            # Nothing left inside the image: emit a black placeholder square.
            resized = np.full((emptyw, emptyw, 3), fill_value=0, dtype=np.uint8)
            width = emptyw
        else:
            scale_percent = min(1920 / img.shape[0], max_width / w)
            if scale_percent < 1:
                width = max(1, int(bub.shape[1] * scale_percent))
                height = max(1, int(bub.shape[0] * scale_percent))
                if resize:
                    resized = cv2.resize(bub, (width, height), interpolation = cv2.INTER_AREA)
                else:
                    resized = bub
            else:
                width = w
                resized = bub

        imwrite(cut_path, resized, '.jpg')
        cuts_path_list.append(cut_path)
        cut_width_list.append(width)

    return cuts_path_list, cut_width_list
638
+
639
+
640
def save_image_with_watermark(
    img: 'Union[np.ndarray, Image.Image]',
    output_path: str,
    watermark_path: str = None,
    watermark_opacity: float = 0.7,
    quality: int = 95
) -> bool:
    """Save image with optional watermark applied.

    Args:
        img: OpenCV-style array (grayscale, BGR or BGRA) or a PIL image.
        output_path: destination file; its extension selects the encoder
            (.jpg/.jpeg, .webp, .jxl, anything else is written as PNG).
        watermark_path: watermark image; ignored when empty or missing on disk.
        watermark_opacity: opacity forwarded to the watermark helper.
        quality: encoder quality for the lossy formats.

    Returns:
        True on success, False when anything failed (the error is logged).
    """
    try:
        import os.path as osp
        from PIL import Image

        if isinstance(img, np.ndarray):
            if img.ndim == 2:
                img_pil = Image.fromarray(img)
            else:
                # cv2 is only needed for the BGR(A) -> RGB(A) channel swap.
                import cv2
                if img.shape[2] == 4:
                    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
                else:
                    img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        else:
            img_pil = img

        if watermark_path and osp.exists(watermark_path):
            # Imported lazily so a plain save does not depend on the helper module.
            from utils.watermark_utils import apply_watermark_to_pil_image
            img_pil = apply_watermark_to_pil_image(img_pil, watermark_path, watermark_opacity)

        ext = osp.splitext(output_path)[1].lower()
        img_format = "PNG"
        if ext in ['.jpg', '.jpeg']:
            img_format = "JPEG"
            # JPEG cannot store alpha or palette data; convert every mode the
            # encoder rejects (RGBA, LA, P, ...), not just RGBA.
            if img_pil.mode not in ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr'):
                img_pil = img_pil.convert('RGB')
        elif ext == '.webp':
            img_format = "WEBP"
        elif ext == '.jxl':
            # JPEG XL is not JPEG 2000. Prefer whichever encoder is registered
            # for '.jxl' (e.g. via a plugin); otherwise keep the legacy
            # JPEG2000 fallback so existing setups behave as before.
            # NOTE(review): the fallback writes JPEG 2000 data into a .jxl file.
            img_format = Image.registered_extensions().get('.jxl', 'JPEG2000')

        save_kwargs = {'format': img_format}
        if img_format == 'PNG':
            save_kwargs['compress_level'] = 3
        else:
            save_kwargs['quality'] = quality

        img_pil.save(output_path, **save_kwargs)
        return True
    except Exception as e:
        LOGGER.error(f"Error saving image with watermark: {str(e)}")
        return False
689
+
690
+
691
def save_result(self, imgname: str, img: np.ndarray) -> bool:
    """Write a result image, stamping the configured watermark first when enabled.

    Args:
        imgname: page image name, resolved through ``self.get_result_path``.
        img: OpenCV-style array (grayscale, BGR or BGRA).

    Returns:
        Whatever ``imwrite`` returns for the output path.
    """
    output_path = self.get_result_path(imgname)

    if self.watermark_enabled and self.watermark_path:
        # Convert to a PIL image (OpenCV arrays are BGR/BGRA ordered).
        if img.ndim == 2:
            img_pil = Image.fromarray(img)
        elif img.shape[2] == 3:
            img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        elif img.shape[2] == 4:
            img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
        else:
            # Previously this fell through with img_pil unbound and crashed;
            # save the image without a watermark instead.
            LOGGER.warning(f'unsupported channel count {img.shape[2]}, saving without watermark')
            return imwrite(output_path, img)

        # Apply the watermark.
        img_pil = apply_watermark_to_pil_image(
            img_pil,
            self.watermark_path,
            self.watermark_opacity
        )

        # Convert back to a numpy array in OpenCV channel order.
        if img_pil.mode == 'RGB':
            img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
        elif img_pil.mode == 'RGBA':
            img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGBA2BGRA)
        else:
            img = np.array(img_pil)

    return imwrite(output_path, img)
719
+
720
+
utils/registry.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
2
+
3
+ import inspect
4
+ import warnings
5
+ from functools import partial
6
+
7
class Registry:
    """A registry to map strings to classes.

    Example:
        >>> MODELS = Registry('models')
        >>> @MODELS.register_module()
        >>> class ResNet:
        >>>     pass
        >>> resnet_cls = MODELS.get('ResNet')

    This is a trimmed copy of mmcv's registry: there is no ``build`` method,
    and ``build_func`` is accepted only to keep the signature compatible.

    Args:
        name (str): Registry name.
        build_func (func, optional): Ignored. Default: None.
        parent (Registry, optional): Parent registry. Classes registered in a
            child can be fetched from the parent as ``'<scope>.<name>'``.
            Default: None.
        scope (str, optional): The scope of this registry, i.e. the key under
            which a parent stores it. It is not inferred automatically; a
            registry must be given a scope to be attached to a parent.
            Default: None.
    """

    def __init__(self, name, build_func=None, parent=None, scope=None):
        self._name = name
        self._module_dict = dict()
        self._children = dict()
        # get() and _add_children() both read the scope, so it must always
        # exist; it stays None unless the caller passes one.
        self._scope = scope

        self.parent = None
        if parent is not None:
            assert isinstance(parent, Registry)
            parent._add_children(self)
            self.parent = parent

    def __len__(self):
        return len(self._module_dict)

    def __contains__(self, key):
        return self.get(key) is not None

    def __repr__(self):
        format_str = self.__class__.__name__ + \
                     f'(name={self._name}, ' \
                     f'items={self._module_dict})'
        return format_str

    @staticmethod
    def infer_scope():
        """Infer the scope of registry.

        Returns the top-level package name of the module two frames up the
        call stack. Kept for compatibility; the constructor does not call it.

        Returns:
            str: The inferred scope name.
        """
        # inspect.stack() trace where this function is called, the index-2
        # indicates the frame where `infer_scope()` is called
        filename = inspect.getmodule(inspect.stack()[2][0]).__name__
        split_filename = filename.split('.')
        return split_filename[0]

    @staticmethod
    def split_scope_key(key):
        """Split scope and key.

        The first scope will be split from key.

        Examples:
            >>> Registry.split_scope_key('mmdet.ResNet')
            'mmdet', 'ResNet'
            >>> Registry.split_scope_key('ResNet')
            None, 'ResNet'

        Return:
            tuple[str | None, str]: The former element is the first scope of
            the key, which can be ``None``. The latter is the remaining key.
        """
        split_index = key.find('.')
        if split_index != -1:
            return key[:split_index], key[split_index + 1:]
        else:
            return None, key

    @property
    def name(self):
        return self._name

    @property
    def scope(self):
        return self._scope

    @property
    def module_dict(self):
        return self._module_dict

    @property
    def children(self):
        return self._children

    def get(self, key):
        """Get the registry record.

        Args:
            key (str): The class name in string format, optionally prefixed
                with ``'<scope>.'``.

        Returns:
            class | None: The corresponding class, or None when not found.
        """
        # An exact match wins, so names that themselves contain a dot
        # still resolve instead of being mistaken for a scoped key.
        if key in self._module_dict:
            return self._module_dict[key]

        scope, real_key = self.split_scope_key(key)
        if scope is None or scope == self._scope:
            # get from self
            return self._module_dict.get(real_key)

        # get from self._children
        if scope in self._children:
            return self._children[scope].get(real_key)

        # goto root; a root registry with no matching child has nothing to offer
        root = self.parent
        if root is None:
            return None
        while root.parent is not None:
            root = root.parent
        return root.get(key)

    def _add_children(self, registry):
        """Add children for a registry.

        The ``registry`` will be added as children based on its scope, so
        that its classes can be fetched from this registry.

        Example:
            >>> models = Registry('models')
            >>> sub_models = Registry('models', parent=models, scope='sub')
            >>> @sub_models.register_module()
            >>> class ResNet:
            >>>     pass
            >>> resnet_cls = models.get('sub.ResNet')
        """

        assert isinstance(registry, Registry)
        assert registry.scope is not None
        assert registry.scope not in self.children, \
            f'scope {registry.scope} exists in {self.name} registry'
        self.children[registry.scope] = registry

    def _register_module(self, module_class, module_name=None, force=False):
        if not inspect.isclass(module_class):
            raise TypeError('module must be a class, '
                            f'but got {type(module_class)}')

        if module_name is None:
            module_name = module_class.__name__
        if isinstance(module_name, str):
            module_name = [module_name]

        for name in module_name:
            if not force and name in self._module_dict:
                raise KeyError(f'{name} is already registered '
                               f'in {self.name}')
            self._module_dict[name] = module_class

    def deprecated_register_module(self, cls=None, force=False):
        warnings.warn(
            'The old API of register_module(module, force=False) '
            'is deprecated and will be removed, please use the new API '
            'register_module(name=None, force=False, module=None) instead.',
            DeprecationWarning)
        if cls is None:
            return partial(self.deprecated_register_module, force=force)
        self._register_module(cls, force=force)
        return cls

    def register_module(self, name=None, force=False, module=None):
        """Register a module.

        A record will be added to `self._module_dict`, whose key is the class
        name or the specified name, and value is the class itself.
        It can be used as a decorator or a normal function.

        Example:
            >>> backbones = Registry('backbone')
            >>> @backbones.register_module()
            >>> class ResNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> @backbones.register_module(name='mnet')
            >>> class MobileNet:
            >>>     pass

            >>> backbones = Registry('backbone')
            >>> class ResNet:
            >>>     pass
            >>> backbones.register_module(ResNet)

        Args:
            name (str | list[str] | tuple[str] | None): The module name(s) to
                be registered. If not specified, the class name will be used.
            force (bool, optional): Whether to override an existing class with
                the same name. Default: False.
            module (type): Module class to be registered.
        """
        if not isinstance(force, bool):
            raise TypeError(f'force must be a boolean, but got {type(force)}')
        # NOTE: This is a walkaround to be compatible with the old api,
        # while it may introduce unexpected bugs.
        if isinstance(name, type):
            return self.deprecated_register_module(name, force=force)

        # raise the error ahead of time; a list/tuple of str is accepted,
        # matching both the error message and _register_module
        is_str_seq = isinstance(name, (list, tuple)) and \
            all(isinstance(item, str) for item in name)
        if not (name is None or isinstance(name, str) or is_str_seq):
            raise TypeError(
                'name must be either of None, an instance of str or a sequence'
                f'  of str, but got {type(name)}')

        # use it as a normal method: x.register_module(module=SomeClass)
        if module is not None:
            self._register_module(
                module_class=module, module_name=name, force=force)
            return module

        # use it as a decorator: @x.register_module()
        def _register(cls):
            self._register_module(
                module_class=cls, module_name=name, force=force)
            return cls

        return _register

    def __getitem__(self, key: str):
        return self.get(key)
utils/shared.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+ import os
3
+ import os.path as osp
4
+ import json
5
+ import sys
6
+
7
# ---- paths (all resolved relative to the program root) ----
ICON_PATH = 'icons/icon.icns'

PROGRAM_PATH = osp.abspath(osp.dirname(osp.dirname(__file__)))
LOGGING_PATH = osp.join(PROGRAM_PATH, 'logs')

LIBS_PATH = osp.join(PROGRAM_PATH, 'data/libs')

STYLESHEET_PATH = osp.join(PROGRAM_PATH, 'config/stylesheet.css')
THEME_PATH = osp.join(PROGRAM_PATH, 'config/themes.json')
CONFIG_PATH = osp.join(PROGRAM_PATH, 'config/config.json')

DEFAULT_TEXTSTYLE_DIR = osp.join(PROGRAM_PATH, 'config/textstyles')
# exist_ok avoids the check-then-create race of a separate exists() test
os.makedirs(DEFAULT_TEXTSTYLE_DIR, exist_ok=True)

# ---- UI sizing ----
st_manager = None
CONFIG_FONTSIZE_HEADER = 18
CONFIG_FONTSIZE_TABLE = 16
CONFIG_FONTSIZE_CONTENT = 16

CONFIG_COMBOBOX_HEIGHT = 30
CONFIG_COMBOBOX_SHORT = 200
# "MIDEAN" is a misspelling of "MEDIAN"; the name is kept because other
# modules may import it.
CONFIG_COMBOBOX_MIDEAN = 332
CONFIG_COMBOBOX_LONG = 468

_size2width = {
    'short': CONFIG_COMBOBOX_SHORT,
    'median': CONFIG_COMBOBOX_MIDEAN,
    'long':CONFIG_COMBOBOX_LONG
}

def size2width(size: str):
    """Map a size name ('short', 'median', 'long') to its combobox width.

    Raises KeyError for any other name.
    """
    return _size2width[size]

HORSLIDER_FIXHEIGHT = 36

WIDGET_SPACING_CLOSE = 8
TEXTEDIT_FIXWIDTH = 350

TEXTEFFECT_FIXWIDTH = 400
TEXTEFFECT_MAXHEIGHT = 500

LEFTBAR_WIDTH = 48
LEFTBTN_WIDTH = 28

LDPI = 96.
DPI = 188.75

SCREEN_H = 2160
SCREEN_W = 3840

DEFAULT_FONT_FAMILY = 'Microsoft YaHei UI'
APP_DEFAULT_FONT = 'Microsoft YaHei UI'

WINDOW_BORDER_WIDTH = 4
BOTTOMBAR_HEIGHT = 32
TITLEBAR_HEIGHT = 30

PAGELIST_THUMBNAIL_MAXNUM = 100
PAGELIST_THUMBNAIL_SIZE = 48

FLAG_QT6 = True

SLIDERHANDLE_COLOR = (85,85,96)
FOREGROUND_FONTCOLOR = (93,93,95)

MAX_NUM_LOG = 7

# ---- display languages ----
TRANSLATE_DIR = osp.join(PROGRAM_PATH, 'translate')
DISPLAY_LANGUAGE_MAP = {
    "English": "English",
    "简体中文": "zh_CN",
    "Русский": "ru_RU",
    "Português (Brasil)": "pt_BR",
    "한국어": "ko_KR",
    "Español": "es_MX",
    "Hungarian": "hu_HU"
}
VALID_LANG_SET = set(list(DISPLAY_LANGUAGE_MAP.values()))

# Offer any extra compiled translation (*.qm) under its own file stem.
# A missing translate folder must not break importing this module, and only
# the trailing '.qm' is removed (str.replace would strip every occurrence).
# Sorting keeps the menu order stable across platforms.
if osp.isdir(TRANSLATE_DIR):
    for p in sorted(os.listdir(TRANSLATE_DIR)):
        if p.endswith('.qm'):
            lang = p[:-len('.qm')]
            if lang not in VALID_LANG_SET:
                DISPLAY_LANGUAGE_MAP[lang] = lang

DEFAULT_DISPLAY_LANG = 'English'

# ---- runtime flags ----
USE_PYSIDE6 = False
ON_MACOS = sys.platform == 'darwin'
ON_WINDOWS = sys.platform == 'win32'
HEADLESS = False
DEBUG = False
args = None

FUZZY_MATCH_IMAGE_NAME = False

# ---- cache state, filled by load_cache() and written by dump_cache() ----
cache_data: Dict = None
cache_dir: str = osp.join(PROGRAM_PATH, '.btrans_cache')
cache_path: str = osp.join(PROGRAM_PATH, '.btrans_cache/cache.json')
CACHE_UPDATED = False
check_local_file_hash = True

FONT_FAMILIES: set = None
CUSTOM_FONTS = []
pbar = {}
runtime_widget_set = set()
115
+
116
def add_to_runtime_widget_set(widget):
    """Record ``widget`` in the module-level ``runtime_widget_set``."""
    runtime_widget_set.update((widget,))
118
+
119
def remove_from_runtime_widget_set(widget):
    """Drop ``widget`` from ``runtime_widget_set``; a no-op when it is absent."""
    runtime_widget_set.discard(widget)
122
+
123
showed_exception = set()


# it will be set to ui.mainwindow.create_errdialog.emit after UI initialized
def create_errdialog_in_mainthread(*args, **kwargs):
    """Default no-op error-dialog callback; accepts anything, returns None."""
    return None


def create_infodialog_in_mainthread(*args, **kwargs):
    """Default no-op info-dialog callback; accepts anything, returns None."""
    return None
129
+
130
def load_cache():
    """Populate the module-level ``cache_data`` from ``cache_path`` once.

    Does nothing when ``cache_data`` is already loaded. A missing, unreadable
    or malformed cache file results in an empty dict.
    """
    global cache_data
    if cache_data is not None:
        return

    if not osp.exists(cache_path):
        cache_data = {}
        return

    try:
        with open(cache_path, "r", encoding="utf8") as file:
            cache_data = json.load(file)
    except (OSError, ValueError):
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        # A bare except would also swallow KeyboardInterrupt/SystemExit.
        print(f'cached file {cache_path} is invalid')
        cache_data = {}
142
+
143
def dump_cache():
    """Write ``cache_data`` to ``cache_path`` as JSON and clear ``CACHE_UPDATED``.

    Does nothing when the cache was never loaded (``cache_data`` is None).
    The parent folder of ``cache_path`` is created when needed.
    """
    global CACHE_UPDATED
    if cache_data is None:
        return

    target_dir = osp.dirname(cache_path)
    # dirname is '' for a bare file name, and os.makedirs('') raises;
    # exist_ok avoids the check-then-create race.
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    with open(cache_path, "w", encoding="utf8") as file:
        json.dump(cache_data, file, indent=4)

    CACHE_UPDATED = False
157
+
158
config_name_to_view_widget = {}
action_to_view_config_name = {}
# Default no-op callback. The original line used ':' instead of '=', which
# only annotated the name and never bound it, so any use raised an error.
# NOTE(review): presumably replaced by the UI after initialization, like the
# dialog callbacks earlier in this module - confirm.
register_view_widget = lambda *args, **kwargs: None
utils/split_text_region.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2, os, re, random
2
+ import numpy as np
3
+ # import tesserocr
4
+ # from tesserocr import PyTessBaseAPI, PSM, OEM
5
+
6
+
7
+
8
class TextSpan(object):
    """Rectangle described by optional top/bottom/left/right bounds.

    ``height`` and ``width`` are derived from the bounds and are None while
    either of the two bounds they depend on is unknown. Indexing yields
    ``[left, top, right, bottom]``.
    """

    def __init__(self, top_bnd=None, bottom_bnd=None, left_bnd=None, right_bnd=None):
        self.top = top_bnd
        self.bottom = bottom_bnd
        self.height = self._extent(top_bnd, bottom_bnd)

        self.left = left_bnd
        self.right = right_bnd
        self.width = self._extent(left_bnd, right_bnd)

    @staticmethod
    def _extent(start, end):
        """Return ``end - start``, or None when either bound is missing."""
        if start is None or end is None:
            return None
        return end - start

    def set_top(self, top_bnd):
        """Set the top bound and refresh ``height``; always returns True."""
        self.top = top_bnd
        # keep height in sync instead of leaving a stale value behind
        self.height = self._extent(self.top, self.bottom)
        return True

    def set_bottom(self, bottom_bnd):
        """Set the bottom bound; refused (False) unless it lies below a known top."""
        if self.top is None or bottom_bnd <= self.top:
            return False
        self.bottom = bottom_bnd
        self.height = self.bottom - self.top
        return True

    def set_left(self, left_bnd):
        """Set the left bound and refresh ``width``; always returns True."""
        self.left = left_bnd
        self.width = self._extent(self.left, self.right)
        return True

    def set_right(self, right_bnd):
        """Set the right bound; refused (False) unless it lies right of a known left."""
        if self.left is None or right_bnd <= self.left:
            return False
        self.right = right_bnd
        self.width = right_bnd - self.left
        return True

    def __getitem__(self, index):
        """Return left/top/right/bottom for index 0..3.

        Raises AttributeError (kept for compatibility) for any other index.
        """
        if isinstance(index, int) and index >=0 and index < 4:
            return [self.left, self.top, self.right, self.bottom][index]
        else:
            raise AttributeError(f'Invalid key: {index}')
45
+
46
def split_step0(span, thresh, sumby_yaxis, thresh2=None) -> list[TextSpan]:
    """Split ``span`` into vertical sub-spans at gaps in the above-threshold rows.

    Rows of ``sumby_yaxis`` between ``span.top`` and ``span.bottom`` whose
    value exceeds ``thresh`` are candidate rows; consecutive candidates
    further apart than ``stride_tol`` mark a gap. Each finished run is handed
    to ``split_step1`` together with ``thresh2``, which either appends it or
    splits it again.

    Returns:
        The list of spans, or None when no row exceeds ``thresh`` or there is
        only a single candidate row.
    """
    # Row indices (absolute, not relative to span.top) above the threshold.
    candidate_pnts = (np.where(sumby_yaxis[span.top: span.bottom] > thresh)[0] + span.top).tolist()
    span_list = []
    if len(candidate_pnts) == 0:
        return None
    stride_tol = 1
    # span0 is the run being built; span1 is the empty span that becomes the next run.
    span0, span1 = TextSpan(candidate_pnts[0]), TextSpan()
    for pnt_ind in range(len(candidate_pnts)-1):
        if candidate_pnts[pnt_ind+1] - candidate_pnts[pnt_ind] > stride_tol:
            # set_bottom refuses a bottom that is not below the top (a run of
            # one row); in that case span0 keeps its top and keeps growing.
            if not span0.set_bottom(candidate_pnts[pnt_ind]):
                continue
            span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)
            span1.set_top(candidate_pnts[pnt_ind+1])
            span0 = span1
            span1 = TextSpan()

    if len(candidate_pnts)-1 == 0:
        # Exactly one candidate row, so first and last are the same element.
        if candidate_pnts[0] == candidate_pnts[-1]:
            span_list = None
        else:
            span0 = TextSpan(candidate_pnts[0], candidate_pnts[-1])
            span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)
    elif span0.top != candidate_pnts[-1]:
        # Close the run that was still open when the candidates ran out.
        span0.set_bottom(candidate_pnts[-1])
        span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)

    return span_list
73
+
74
+
75
+
76
def split_step1(span, span_list, thresh=None, sumby_yaxis=None):
    """Append ``span`` to ``span_list``, or its sub-spans when a finer split is accepted.

    With ``thresh`` None the span is appended as is. Otherwise the span is
    re-split by ``split_step0`` using ``thresh``, and the sub-spans replace it
    only when there are at least two of them, the tallest is at most 2.5
    times the shortest, and together they cover at least 30% of the span's
    height.

    Returns:
        ``span_list`` (the same list object, extended in place).
    """
    if thresh is None:
        span_list.append(span)
        return span_list
    else:
        subspan_list = split_step0(span, thresh, sumby_yaxis)
        # print(np.var(sumby_yaxis[span.top:span.bottom]))
        if subspan_list is not None:

            _, maxspan = find_span(subspan_list, max)
            _, minspan = find_span(subspan_list, min)

            sum_height = sum(c.height for c in subspan_list)

            # Reject the finer split when it is uneven, sparse, or not a split at all.
            if maxspan.height / minspan.height > 2.5 or sum_height / span.height < 0.3 or len(subspan_list) == 1:
                subspan_list = None
        if subspan_list is not None and len(subspan_list) > 1:
            span_list += subspan_list
        else:
            span_list.append(span)
    return span_list
97
+
98
+
99
+
100
def shrink_span_list(src_img, span_list, shrink_vert_space=True, shrink_hor_space=True):
    """Adjust span bounds in place and measure how consistently the spans align.

    Vertically, the gap between neighbouring spans is split between them and
    the first span is extended upwards by half the mean gap. Horizontally,
    each span's left/right is set to the first/last column whose mean
    intensity within the span exceeds 10.

    Returns:
        ``(span_list, (left_var, middle_var))`` where the variances are taken
        over the spans' left edges and horizontal centres, or ``(-1, -1)``
        when ``shrink_hor_space`` is false.
    """
    height, width = src_img.shape[0], src_img.shape[1]

    sum_spacing = 0
    if shrink_vert_space:
        for ii in range(len(span_list)-1):
            line_spacing = span_list[ii+1].top - span_list[ii].bottom
            sum_spacing += line_spacing
            # Give each neighbour half of the gap between them.
            line_spacing = int(round(line_spacing / 2))
            span_list[ii+1].top -= line_spacing
            span_list[ii].set_bottom(span_list[ii].bottom + line_spacing)

        if len(span_list) >= 2:
            mean_spacing = int(0.5 * round(sum_spacing / (len(span_list)-1)))
            span_list[0].top = max(0, span_list[0].top-mean_spacing)
            # Re-setting the same bottom recomputes height after top was changed directly.
            span_list[0].set_bottom(span_list[0].bottom)
            span_list[-1].set_bottom(min(src_img.shape[0], span_list[-1].bottom))

    left_var, middle_var = -1, -1
    if shrink_hor_space:
        left_pnts, middle_pnts = [], []
        for ii in range(len(span_list)):
            s = span_list[ii]
            im = src_img[s.top: s.bottom, 0: width]
            # Mean over rows gives one value per column, despite the variable name.
            sumby_yaxis = np.mean(im, axis=0)
            content_array = np.where(sumby_yaxis > 10)[0].tolist()
            left, right = 0, width
            if len(content_array) != 0:
                left, right = content_array[0], content_array[-1]
            span_list[ii].set_left(left)
            span_list[ii].set_right(right)
            s = span_list[ii]
            left_pnts.append(left)
            middle_pnts.append((left+right)/2)
        left_var, middle_var = np.var(np.array(left_pnts)), np.var(np.array(middle_pnts))

    return span_list, (left_var, middle_var)
137
+
138
+
139
+
140
def find_span(span_list, max_or_min=max, key="height"):
    """Find the extreme span by height (default) or by width.

    Args:
        span_list: spans exposing ``height`` / ``width``.
        max_or_min: the builtin ``max`` or ``min``.
        key: ``"height"`` compares heights; any other value compares widths.

    Returns:
        An ``(index, span)`` tuple, or ``-1`` when ``span_list`` is empty.
    """
    attr = "height" if key=="height" else "width"

    def measure(pair):
        position, _ = pair
        return getattr(span_list[position], attr)

    return max_or_min(enumerate(span_list), key=measure, default = -1)
145
+
146
+
147
+
148
def discard_spans(span_list, thresh_ratio=0.3):
    """Drop spans much shorter than the tallest one.

    Args:
        span_list: spans exposing ``height``.
        thresh_ratio: a span is kept when its height is at least this
            fraction of the tallest span's height.

    Returns:
        A new list with the surviving spans in their original order; an empty
        input yields an empty list.
    """
    if not span_list:
        # previously this unpacked find_span's -1 default and raised TypeError
        return []
    height_thresh = max(sp.height for sp in span_list) * thresh_ratio
    return [sp for sp in span_list if sp.height >= height_thresh]
159
+
160
+
161
+
162
def plot_mapresult(sumbyvector, xlength, span_list=None, thresh=None):
    """Plot a projection profile with span bounds and threshold lines (for experiments).

    Best effort: any ordinary failure (matplotlib missing, no display, bad
    input) is ignored and the function returns None.

    Args:
        sumbyvector: the profile to plot.
        xlength: right end of the horizontal threshold lines.
        span_list: optional spans whose top/bottom are drawn as vertical lines.
        thresh: optional iterable of ratios; each draws a dashed line at
            ``ratio * sumbyvector.mean()``.
    """
    try:
        import matplotlib.pyplot as plt
        plt.plot(sumbyvector)
        plt.ylabel('div pnt value')
        plt.xlabel('div pnt coord')
        s = [0, 255]
        x_cords = []
        if span_list is not None:
            for sp in span_list:
                x_cords.append(sp.top)
                x_cords.append(sp.bottom)
        if thresh is not None:
            for tr in thresh:
                plt.vlines(x = x_cords, ymin = 0, ymax = max(s),
                           colors = 'purple',
                           label = 'vline_multiple - full height')
                plt.hlines(y = tr * sumbyvector.mean(), xmin = 0, xmax = xlength, linestyles='--')
        plt.show()
    except Exception:
        # still best-effort, but no longer swallows KeyboardInterrupt/SystemExit
        pass
184
+
185
+
186
+
187
def box(width, height):
    """Return a ``height`` x ``width`` uint8 array filled with ones."""
    kernel_shape = (height, width)
    return np.full(kernel_shape, 1, dtype=np.uint8)
189
+
190
+
191
def crop_img(img, crop_ratio=0.2, clip_width=True, dilate=False):
    """Crop an image horizontally around its intensity centroid.

    Args:
        img: single-channel image (it is passed to ``cv2.moments``).
        crop_ratio: half-width of the kept band as a fraction of the image width.
        clip_width: when true keep ``centroid_x +/- crop_ratio * width``;
            otherwise drop the leftmost ``2 * crop_ratio * width`` columns.
        dilate: when true (and ``clip_width``), dilate horizontally with a
            one-row kernel a seventh of the original width wide.

    Returns:
        ``(image, height, width)`` of the processed image. An image with zero
        total intensity is not cropped.
    """
    w = img.shape[1]
    moments = cv2.moments(img)
    area = moments['m00']
    if area != 0:
        mean_x = int(round(moments['m10'] / area))
        crop_r = int(round(crop_ratio * w))
        if clip_width:
            crop_x0 = np.clip(mean_x - crop_r, 0, w)
            crop_x1 = np.clip(mean_x + crop_r, 0, w)
            if crop_x1 > crop_x0:
                img = img[:, crop_x0: crop_x1]
        else:
            crop_r = np.clip(crop_r * 2, 0, w - 1)
            img = img[:, crop_r:]
        img = np.copy(img)
    if clip_width and dilate:
        kernel_w = int(round(w/7))
        if kernel_w > 1:
            # The third positional argument of cv2.dilate is `dst`, not the
            # iteration count; pass iterations by keyword as intended.
            img = cv2.dilate(img, box(kernel_w, 1), iterations=1)
    return img, img.shape[0], img.shape[1]
213
+
214
+
215
+
216
def split_textblock(src_img, crop_ratio=0.2, blur=False, show_process=False, discard=True, shrink=True, recheck=False, clip_width=True, dilate=True):
    """Split a text-block image into line spans using its row-mean profile.

    Args:
        src_img: image of the text block.
        crop_ratio: when > 0 the profile is computed on ``crop_img``'s output.
        blur: apply a 3x3 Gaussian blur first.
        show_process: plot the profile via ``plot_mapresult``.
        discard: drop spans much shorter than the tallest one.
        shrink: refine span bounds with ``shrink_span_list``.
        recheck: when cropping produced a single span, retry once without cropping.
        clip_width, dilate: forwarded to ``crop_img``.

    Returns:
        ``(span_list, (left_var, middle_var))``; the variances are
        ``(-1, -1)`` unless ``shrink`` ran. ``(None, None)`` when no span
        could be extracted.
    """

    if blur:
        # NOTE(review): the third positional argument of GaussianBlur is
        # sigmaX, so cv2.BORDER_DEFAULT acts as the sigma here - left as is
        # because changing it would alter results.
        src_img = cv2.GaussianBlur(src_img,(3,3),cv2.BORDER_DEFAULT)
    if crop_ratio > 0:
        img, height, width = crop_img(src_img, crop_ratio=crop_ratio, clip_width=clip_width, dilate=dilate)
    else:
        img, height, width = src_img, src_img.shape[0], src_img.shape[1]

    sumby_yaxis = img.mean(axis=1)
    bound0 = np.where(sumby_yaxis > sumby_yaxis.mean() * 0.1)[0].tolist()
    variances = (-1, -1)

    if len(bound0) < 2:
        # Fewer than two rows with content: treat the whole image as one span.
        return [TextSpan(0, height-1, 0, width - 1)], variances

    base_span = TextSpan(bound0[0], bound0[-1])
    meanby_yaxis = sumby_yaxis.mean()

    thresh_ratio = [0.4, 0.8]
    thresh0 = meanby_yaxis * thresh_ratio[0]
    thresh2 = meanby_yaxis * thresh_ratio[1]

    span_list = split_step0(base_span, thresh0, sumby_yaxis, thresh2=thresh2)
    if span_list is None:
        return None, None
    if discard:
        span_list = discard_spans(span_list)
    if shrink:
        span_list, variances = shrink_span_list(src_img, span_list)

    # for experiment
    if show_process:
        plot_mapresult(sumby_yaxis, height, span_list=span_list, thresh=thresh_ratio)

    if recheck and len(span_list) == 1 and crop_ratio > 0:
        # Retry on the uncropped image. The original passed `crop_ratio==-1`
        # (a comparison yielding False); both disable cropping, this states it.
        return split_textblock(src_img, crop_ratio=-1, show_process=show_process, discard=discard, shrink=shrink, recheck=False)

    # Fill in any bound that is still unset with the image extent.
    valid_span_list = []
    for span in span_list:
        if span.top is None:
            span.set_top(0)
        if span.left is None:
            span.set_left(0)
        if span.right is None:
            span.set_right(width)
        if span.bottom is None:
            span.set_bottom(height)
        valid_span_list.append(span)

    return valid_span_list, variances
267
+
268
+
269
+
270
+ # def tessocr_img2text(img, lang):
271
+ # img = Image.fromarray(img)
272
+ # if re.findall("vert", lang):
273
+ # psm = PSM.SINGLE_BLOCK_VERT_TEXT
274
+ # else:
275
+ # psm = PSM.SINGLE_LINE
276
+ # return tesserocr.image_to_text(img, psm=psm, lang=lang, path=TESSDATA_PATH)
277
+
278
+ # def tessocr_img2text(img, lang):
279
+ # psm = "5" if re.findall("vert", lang) else "7"
280
+ # config = r'--tessdata-dir "models\tessdata" --psm ' + psm
281
+ # return pytesseract.image_to_string(img, lang=lang, config=config)
282
+
283
+
284
def textspan2list(span_list):
    """Convert span objects into ``[top, left, bottom, right]`` lists.

    Each element of ``span_list`` must expose ``top``, ``left``, ``bottom``
    and ``right`` attributes; the output keeps the input order.
    """
    return [[span.top, span.left, span.bottom, span.right] for span in span_list]
293
+
294
+
295
+
296
def manga_split(img, bbox=None, show_process=False, clip_width=False) -> list[TextSpan]:
    """Split an image into spans by running ``split_textblock`` on a rotated copy.

    The image is rotated 90 degrees clockwise, split without cropping,
    discarding or shrinking, and each resulting span is rewritten in place so
    that its coordinates refer to the unrotated orientation.

    ``bbox`` and ``clip_width`` are accepted but do not influence the result.

    Returns:
        The list of spans produced by ``split_textblock`` (mutated), or a
        single fallback ``TextSpan`` when the split returns ``None``.
    """
    im = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    imh, imw = im.shape[:2]

    span_list, _ = split_textblock(im, show_process=show_process, shrink=False, recheck=True, discard=False, crop_ratio=0)
    if span_list is None:
        # NOTE(review): argument order differs from the
        # TextSpan(top, bottom, left, right) usage in split_textblock -- confirm.
        return [TextSpan(0, 0, imw, imh)]
    # span_list, _ = shrink_span_list(im, span_list, shrink_vert_space=False)

    last_index = len(span_list) - 1
    for index, span in enumerate(span_list):
        # remember the rotated-frame horizontal extent before overwriting it
        rot_left, rot_right = span.left, span.right
        # the first/last span is stretched to the image border
        span.left = 0 if index == 0 else span.top
        span.right = imh if index == last_index else span.bottom
        span.top = imw - rot_right
        span.bottom = imw - rot_left
        span.height = span.bottom - span.top
        span.width = span.right - span.left

    return span_list
327
+
328
+
329
def tessocr_img2text_linemode(img, span_list=None, combine_lines=True, show_process=False, gen_data=False, lang="comic6k", jpn_vert=False):
    """Run OCR on a text block line by line, or on all lines joined into one strip.

    Args:
        img: image array of the text block.
        span_list: optional spans, either objects with ``top/bottom/left/right``
            or ``[top, left, bottom, right]`` lists. When ``None`` the spans
            are computed with ``split_textblock``.
        combine_lines: concatenate all padded line crops horizontally and OCR
            the strip once; otherwise OCR each crop and join with newlines.
        show_process: display intermediate crops with ``cv2.imshow``.
        gen_data: with ``combine_lines``, return the concatenated strip
            instead of running OCR.
        lang: language passed to ``tessocr_img2text``.
        jpn_vert: force ``lang="jpn_vert"`` and rotate the image
            counter-clockwise first (the joined strip is rotated back).

    Returns:
        ``(text, mean_height, alignment)``; ``('', -1, -1)`` for images with
        fewer than 5 pixels or when no spans can be computed; or the
        concatenated strip when ``gen_data`` is set.
    """
    if jpn_vert:
        lang = "jpn_vert"
        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    hig, wid = img.shape[0], img.shape[1]
    if hig * wid < 5:
        return '', -1, -1

    bw = 3
    text = ''
    span_vars = (-1, -1)
    if span_list is None:
        span_list, span_vars = split_textblock(img, show_process=show_process)
        if span_list is None:
            # split_textblock signals failure with (None, None)
            return '', -1, -1
        _, maxspan = find_span(span_list, max)
        maxh = bw * 2 + maxspan.height
    else:
        maxh = bw * 2 + max(s[2] - s[0] for s in span_list)

    long_line = []
    word_space = int(round(maxh / 8))
    img = 255 - img
    for ind, s in enumerate(span_list):
        if isinstance(s, list):
            im = img[s[0]: s[2], s[1]: s[3]]
        else:
            im = img[s.top: s.bottom, s.left: s.right]

        # pad every crop vertically to the common height maxh
        hw1 = int(round((maxh - im.shape[0]) / 2))
        hw2 = maxh - hw1 - im.shape[0]
        dst = cv2.copyMakeBorder(im, hw1, hw2, word_space, word_space, cv2.BORDER_CONSTANT, None, value=[255, 255, 255])

        if not combine_lines:
            text += tessocr_img2text(dst, lang=lang) + '\n'
        else:
            long_line.append(dst)
        if show_process:
            cv2.imshow(str(ind), dst)

    if combine_lines:
        long_line = cv2.hconcat(long_line)
        if jpn_vert:
            long_line = cv2.rotate(long_line, cv2.ROTATE_90_CLOCKWISE)
        if show_process:
            cv2.namedWindow("long line:", cv2.WINDOW_NORMAL)
            cv2.imshow("long line:", long_line)
        if gen_data:
            return long_line
        res = tessocr_img2text(long_line, lang=lang)
    else:
        # Previously `res` was never assigned on this path, so the per-line
        # text was lost and the return statement raised NameError.
        res = text

    mean_height = -1
    if len(span_list) != 0:
        if isinstance(span_list[0], list):
            mean_height = np.mean(np.array([s[2] - s[0] for s in span_list]))
        else:
            mean_height = np.mean(np.array([s.height for s in span_list]))
    alignment = 1 if span_vars[1] < span_vars[0] else 0
    return res, mean_height, alignment
utils/stroke_width_calculator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2, os, time
2
+ import numpy as np
3
+
4
+
5
def calculate_derivatives(gx, gy):
    """Normalise the gradient ``(gx, gy)`` to unit length.

    Returns:
        ``(True, gx / |g|, gy / |g|)`` when the magnitude is non-zero,
        otherwise ``(False, -1, -1)``.
    """
    norm = np.sqrt(gx * gx + gy * gy)
    if norm != 0:
        return True, gx / norm, gy / norm
    return False, -1, -1
11
+
12
def sw_calculator(mask, canny_img, gradient_x, gradient_y, show_process=False):
    """Estimate stroke widths by casting rays from sampled foreground points.

    Starting points are pixels where both ``canny_img`` and ``mask`` are
    non-zero. From each sampled point a ray is marched along the normalised
    gradient until it leaves the image or reaches a pixel where ``canny_img``
    is zero.

    Args:
        mask: array selecting which pixels may start a ray (non-zero = allowed).
        canny_img: array whose zero pixels terminate a ray.
        gradient_x, gradient_y: per-pixel gradient components, indexed
            ``[y][x]`` like ``canny_img``.
        show_process: draw the accepted rays and display debug windows
            (blocks on ``cv2.waitKey(0)``).

    Returns:
        List of ``[start_x, start_y, end_x, end_y, length]`` for accepted rays.
    """
    height, width = canny_img.shape[0], canny_img.shape[1]

    if show_process:
        drawborder = np.zeros((canny_img.shape[0], canny_img.shape[1], 3), dtype=np.uint8)

    pnts = np.where(np.logical_and(canny_img != 0, mask!=0))
    total_pnt_num = pnts[0].shape[0]
    # visit roughly sample_pnt_num points at most; the step may be fractional
    sample_pnt_num = 150
    sample_step = total_pnt_num / sample_pnt_num if total_pnt_num > sample_pnt_num else 1

    cur_pnt_ind = 0
    ray_list = []

    while cur_pnt_ind < total_pnt_num:
        start_x, start_y = pnts[1][cur_pnt_ind], pnts[0][cur_pnt_ind]
        # [start_x, start_y, end_x, end_y, length]; -1 marks "no end point found"
        ray_arr = [start_x, start_y, -1, -1, -1]
        valid, dx, dy = calculate_derivatives(gradient_x[start_y][start_x], gradient_y[start_y][start_x])

        if valid:
            inc = 0.2
            cur_x, cur_y = start_x + inc * dx, start_y + inc * dy
            while (True):
                tmp_curx, tmp_cury = int(cur_x), int(cur_y)
                # NOTE(review): the y test uses `<= 0` while the x test uses `< 0`,
                # so rays stop one row early at the top border -- confirm intent.
                if tmp_curx < 0 or tmp_curx >= width or tmp_cury <= 0 or tmp_cury >= height:
                    break
                if canny_img[tmp_cury][tmp_curx] == 0:
                    valid, dx_t, dy_t = calculate_derivatives(gradient_x[tmp_cury][tmp_curx], gradient_y[tmp_cury][tmp_curx])
                    if not valid:
                        break
                    # accept the ray only if the gradient at the end point is within
                    # 90 degrees of the reversed start gradient
                    if np.arccos(-dx * dx_t + -dy * dy_t) < np.pi / 2.0:
                        ray_arr[2] = tmp_curx
                        ray_arr[3] = tmp_cury
                        ray_arr[4] = np.sqrt((start_x - tmp_curx)**2 + (start_y - tmp_cury)**2)
                    # NOTE(review): indentation was lost in this copy of the file; this
                    # break is assumed to end the ray at the first zero pixel -- confirm.
                    break
                cur_x += dx
                cur_y += dy
            if ray_arr[2] != -1:
                ray_list.append(ray_arr)
                if show_process:
                    drawborder = cv2.arrowedLine(drawborder, (ray_arr[0], ray_arr[1]), (ray_arr[2], ray_arr[3]),
                                                 (0, 255, 0), 1)

        cur_pnt_ind += sample_step
        cur_pnt_ind = int(round(cur_pnt_ind))
    if show_process and len(ray_list) != 0:
        ray_list.sort(key=lambda x: x[4])
        cv2.imshow("border", drawborder)
        cv2.imshow("cannyimg", canny_img)
        cv2.waitKey(0)
    return ray_list
63
+
64
def strokewidth_check(text_mask, labels, num_labels, stats, debug_type=0):
    """Zero out components of ``text_mask`` whose stroke width and area are outliers.

    For every labelled component (label 0 excluded) larger than 0.2% of the
    image, the median ray length from ``sw_calculator`` is taken as its stroke
    width. Components whose width exceeds twice the mean width AND whose area
    exceeds twice the mean area of the measured components are erased.

    Args:
        text_mask: mask image; modified in place and returned.
        labels: label image matching ``text_mask``.
        num_labels: number of labels in ``labels``.
        stats: per-label rows read as ``[x, y, w, h, area]``.
            NOTE(review): presumably from cv2.connectedComponentsWithStats --
            confirm against the caller.
        debug_type: values > 0 enable ``sw_calculator``'s debug display.

    Returns:
        ``text_mask`` after the outlier components have been set to 0.
    """
    rays_width = []
    height, width = text_mask.shape[0], text_mask.shape[1]

    # NOTE(review): (3,3) is passed as the kernel and cv2.BORDER_DEFAULT lands in
    # the positional slot after it -- kept as is, confirm this is intended.
    blur_img = cv2.dilate(text_mask, (3, 3), cv2.BORDER_DEFAULT)

    # canny_img = cv2.Canny(cv2.dilate(text_mask, (3,3), 1), 170, 320, L2gradient=True, apertureSize=3)

    _, canny_img = cv2.threshold(text_mask, 1, 255, cv2.THRESH_OTSU + cv2.THRESH_BINARY)
    blur2 = blur_img.astype(float) / 255
    gradient_x = cv2.Scharr(blur2, ddepth=-1, dx=1, dy=0)
    gradient_x = cv2.GaussianBlur(gradient_x, (3, 3), cv2.BORDER_DEFAULT)
    gradient_y = cv2.Scharr(blur2, ddepth=-1, dx=0, dy=1)
    gradient_y = cv2.GaussianBlur(gradient_y, (3, 3), cv2.BORDER_DEFAULT)

    img_area = text_mask.shape[0] * text_mask.shape[1]
    show_process = debug_type > 0
    for lab in range(num_labels):
        stat = stats[lab]
        if lab != 0 and stat[4] > img_area * 0.002:
            # component bounding box grown by 2 px and clamped to the image
            x1, y1, x2, y2 = stat[0] - 2, stat[1] - 2, stat[0] + stat[2] + 2, stat[1] + stat[3] + 2
            x1, x2 = max(x1, 0), min(x2, width)
            y1, y2 = max(y1, 0), min(y2, height)
            labcord = np.where(labels == lab)
            labcord2 = (labcord[0] - y1, labcord[1] - x1)
            text_roi = np.zeros((y2 - y1, x2 - x1), dtype=np.uint8)
            text_roi[labcord2] = 255
            text_roi = cv2.GaussianBlur(text_roi, (3, 3), cv2.BORDER_DEFAULT)
            ray_list = sw_calculator(text_roi,
                                     canny_img[y1: y2, x1: x2],
                                     gradient_x[y1: y2, x1: x2],
                                     gradient_y[y1: y2, x1: x2],
                                     show_process=show_process)
            if len(ray_list) != 0:
                # median ray length represents the component's stroke width
                ray_list.sort(key=lambda x: x[4])
                rays_width.append([int(lab), ray_list[int(len(ray_list) / 2)][4]])

    if len(rays_width) != 0:
        rays_width = np.array(rays_width)
        mean_width = np.mean(rays_width[:, 1])
        # np.int0 was removed in NumPy 2.0; np.intp is the type it aliased
        ma = rays_width[:, 0].astype(np.intp)
        mean_area = np.mean(stats[ma][:, 4])

        false_labels = np.where(rays_width[:, 1] > 2 * mean_width)[0]
        false_labels = rays_width[false_labels, 0].astype(np.int32)
        for fl in false_labels:
            if stats[fl][4] > 2 * mean_area:
                text_mask[np.where(labels == fl)] = 0
    return text_mask
113
+
utils/structures.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple, List, ClassVar, Union, Any, Dict, Set
2
+ from dataclasses import dataclass, field, is_dataclass
3
+ import copy
4
+ import os
5
+
6
+ import numpy as np
7
+
8
+
9
+ # decorator to wrap original __init__
10
+ # https://www.geeksforgeeks.org/creating-nested-dataclass-objects-in-python/
11
def nested_dataclass(*args, **dataclass_kwargs):
    '''
    Dataclass decorator that builds nested dataclass fields from dicts
    and tolerates unknown keyword arguments.

    Usable both as ``@nested_dataclass`` and ``@nested_dataclass(**options)``;
    the options are forwarded to ``dataclasses.dataclass``. Keyword arguments
    that are not annotated on the class are dropped, or collected into
    ``deprecated_attributes`` when the class annotates that name.
    '''
    def wrapper(check_class):

        # turn the class into a dataclass and keep its generated __init__
        check_class = dataclass(check_class, **dataclass_kwargs)
        generated_init = check_class.__init__

        def __init__(self, *args, **kwargs):

            known = self.__annotations__
            keeps_deprecated = 'deprecated_attributes' in known
            dropped = {}
            # iterate over a snapshot because kwargs is mutated below
            for key in tuple(kwargs):
                if key not in known:
                    # unknown key, e.g. loaded from an older config
                    stale = kwargs.pop(key)
                    if keeps_deprecated:
                        dropped[key] = stale
                    continue
                raw = kwargs[key]
                field_type = check_class.__annotations__.get(key)

                # a dict given for a dataclass-typed field is expanded into it
                if isinstance(raw, dict) and is_dataclass(field_type):
                    kwargs[key] = field_type(**raw)

            if dropped:
                kwargs['deprecated_attributes'] = dropped

            generated_init(self, *args, **kwargs)
        check_class.__init__ = __init__

        return check_class

    if args:
        return wrapper(args[0])
    return wrapper
50
+
51
+
52
@dataclass
class Config:
    """Base dataclass adding dict-style access to annotated attributes."""

    def update(self, key: str, value):
        """Assign ``value`` to ``key``; asserts that ``key`` is annotated."""
        assert key in self.__annotations__, f'type object \'{self.__class__.__name__}\' has no attribute {key}'
        setattr(self, key, value)

    @classmethod
    def annotations_set(cls):
        """Return the annotated attribute names as a set."""
        return set(cls.__annotations__)

    def __getitem__(self, key: str):
        """Return the attribute ``key``; asserts that ``key`` is annotated."""
        assert key in self.__annotations__, f'type object \'{self.__class__.__name__}\' has no attribute {key}'
        return self.__getattribute__(key)

    def __setitem__(self, key: str, value):
        """Assign ``value`` to ``key`` without checking the annotations."""
        setattr(self, key, value)

    @classmethod
    def params(cls):
        """Return the class annotation mapping itself (not a copy)."""
        return cls.__annotations__

    def merge(self, target):
        """Copy every annotated attribute of ``target`` onto this instance."""
        for name in target.annotations_set():
            self.update(name, target[name])

    def copy(self):
        """Return a deep copy of this instance."""
        return copy.deepcopy(self)
81
+
82
+
83
# Parent of the directory that contains this file (symlinks resolved).
MODULE_PATH = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
# Directory one level above MODULE_PATH.
BASE_PATH = os.path.dirname(MODULE_PATH)
utils/text_layout.py ADDED
@@ -0,0 +1,477 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import numpy as np
3
+
4
+ from .imgproc_utils import rotate_image
5
+ from .textblock import TextBlock, TextAlignment
6
+
7
class Line:
    """One laid-out line of text with its position, pixel length and side spacing."""

    def __init__(self, text: str = '', pos_x: int = 0, pos_y: int = 0, length: float = 0, spacing: int = 0) -> None:
        self.text = text
        self.pos_x = pos_x
        self.pos_y = pos_y
        self.length = int(length)
        # a non-empty initial text counts as one word
        self.num_words = 1 if text else 0
        self.spacing = 0
        self.add_spacing(spacing)

    def append_right(self, word: str, w_len: int, delimiter: str = ''):
        """Append ``word`` after the current text and grow the length by ``w_len``."""
        self.text = self.text + delimiter + word
        self.num_words += 1 if word else 0
        self.length += w_len

    def append_left(self, word: str, w_len: int, delimiter: str = ''):
        """Prepend ``word`` before the current text and grow the length by ``w_len``."""
        self.text = word + delimiter + self.text
        self.num_words += 1 if word else 0
        self.length += w_len

    def add_spacing(self, spacing: int):
        """Record ``spacing`` and widen the line by it on both sides."""
        self.spacing = spacing
        self.pos_x -= spacing
        self.length += spacing * 2

    def strip_spacing(self):
        """Undo the currently recorded spacing and reset it to zero."""
        current = self.spacing
        self.length -= current * 2
        self.pos_x += current
        self.spacing = 0
41
+
42
def line_is_valid(line: Line, new_len: int, delimiter_len, max_width, words_length, srcline_wlist, line_no: int, line_height, ref_src_lines: bool = False):
    """Decide whether ``line`` may grow to ``new_len``.

    When ``ref_src_lines`` is set, ``max_width`` is first capped by the entry
    of ``srcline_wlist`` at ``line_no`` (or by the largest entry when
    ``line_no`` is out of range) scaled by ``words_length``, plus
    ``delimiter_len`` for every word already on the line.

    Returns:
        True when ``new_len`` is below the limit, or when the current length
        is relatively closer to the limit from below than ``new_len`` is from
        above; False otherwise. ``line_height`` is not used.
    """
    if ref_src_lines:
        if 0 <= line_no < len(srcline_wlist):
            ratio = srcline_wlist[line_no]
        else:
            ratio = max(srcline_wlist)
        src_limit = ratio * words_length + delimiter_len * line.num_words
        max_width = min(max_width, src_limit)

    if new_len < max_width:
        return True
    # over the limit: still accept if overshooting is the smaller relative error
    return bool(line.length / max_width < max_width / new_len)
63
+
64
def layout_lines_aligncenter(
    blk: TextBlock,
    mask: np.ndarray,
    words: List[str],
    centroid: List[int],
    wl_list: List[int],
    delimiter_len: int,
    line_height: int,
    spacing: int = 0,
    delimiter: str = ' ',
    max_central_width: float = np.inf,
    word_break: bool = False,
    ref_src_lines = False,
    srcline_wlist=None,
    start_from_top=False
) -> Tuple[List[Line], Tuple[int, int]]:
    """Break ``words`` into centre-aligned lines that fit inside ``mask``.

    Unless ``start_from_top`` is set, a central line is grown around the word
    nearest the middle of the text, then the remaining words are laid out
    downwards (bottom half) and upwards (top half). With ``start_from_top``
    all words are laid out downwards from the top of the block.

    Args:
        blk: text block; ``line_spacing`` and ``bounding_rect()`` are read.
        mask: 2-D array; a line edge is accepted where the column mean over the
            line's rows exceeds 220.
            NOTE(review): presumably 255 marks the writable region -- confirm.
        words: words to place. With ``start_from_top`` this list is consumed
            in place (``pop(0)``).
        centroid: ``(x, y)`` the lines are centred on, in mask coordinates.
        wl_list: pixel length of each word; consumed in place together with
            ``words`` when ``start_from_top`` is set.
        delimiter_len: pixel length of ``delimiter``.
        line_height: vertical advance per line.
        spacing: side spacing given to each new ``Line``.
        delimiter: string inserted between words.
        max_central_width: width limit handed to ``line_is_valid``.
        word_break: unused in this function.
        ref_src_lines: also constrain lines by ``srcline_wlist``.
        srcline_wlist: per-source-line width ratios; required when
            ``ref_src_lines`` is set.
        start_from_top: skip the central line and fill from the top.

    Returns:
        ``(lines, (adjust_x, adjust_y))``; ``adjust_x`` is always 0 and
        ``adjust_y`` is the downward shift applied to the first lower line.
    """

    # rows trimmed from the bottom of each line when sampling the mask
    lh_pad = 0
    if blk.line_spacing > 1:
        lh_pad = int(np.ceil(line_height - line_height / blk.line_spacing))

    centroid_x, centroid_y = centroid
    adjust_x = adjust_y = 0

    border_thr = 220

    # layout the central line, the center word is approximately aligned with the centroid of the mask
    num_words = len(words)
    len_left, len_right = [], []
    wlst_left, wlst_right = [], []
    sum_left, sum_right = 0, 0
    words_length = sum(wl_list)
    if num_words > 1:
        wl_array = np.array(wl_list, dtype=np.float64)
        wl_cumsums = np.cumsum(wl_array)
        # signed distance of every word's midpoint from the middle of the text
        wl_cumsums = wl_cumsums - wl_cumsums[-1] / 2 - wl_array / 2
        central_index = np.argmin(np.abs(wl_cumsums))

        if central_index > 0:
            wlst_left = words[:central_index]
            len_left = wl_list[:central_index]
            sum_left = np.sum(len_left)
        if central_index < num_words - 1:
            wlst_right = words[central_index + 1:]
            len_right = wl_list[central_index + 1:]
            sum_right = np.sum(len_right)
    else:
        central_index = 0

    pos_y = centroid_y - line_height // 2
    pos_x = centroid_x - wl_list[central_index] // 2

    bh, bw = mask.shape[:2]
    central_line = Line(words[central_index], pos_x, pos_y, wl_list[central_index], spacing)
    line_bottom = pos_y + line_height
    # grow the central line one word at a time, to the left or to the right
    while (sum_left > 0 or sum_right > 0) and not start_from_top:
        left_valid, right_valid = False, False

        if sum_left > 0:
            new_len_l = central_line.length + len_left[-1] + delimiter_len
            new_x_l = centroid_x - new_len_l // 2
            new_r_l = new_x_l + new_len_l
            if (new_x_l > 0 and new_r_l < bw):
                if mask[pos_y: line_bottom - lh_pad, new_x_l].mean() > border_thr and \
                    mask[pos_y: line_bottom - lh_pad, new_r_l].mean() > border_thr:
                    left_valid = True
        if sum_right > 0:
            new_len_r = central_line.length + len_right[0] + delimiter_len
            new_x_r = centroid_x - new_len_r // 2 - line_height // 2
            new_r_r = centroid_x + new_len_r // 2 + line_height // 2
            if (new_x_r > 0 and new_r_r < bw):
                if mask[pos_y: line_bottom - lh_pad, new_x_r].mean() > border_thr and \
                    mask[pos_y: line_bottom - lh_pad, new_r_r].mean() > border_thr:
                    right_valid = True

        # when both sides fit, take from the side with more remaining length
        insert_left = False
        if left_valid and right_valid:
            if sum_left > sum_right:
                insert_left = True
        elif left_valid:
            insert_left = True
        elif not right_valid:
            break

        if insert_left:
            new_len = central_line.length + len_left[-1] + delimiter_len
        else:
            new_len = central_line.length + len_right[0] + delimiter_len

        line_valid = line_is_valid(central_line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, -1, line_height, ref_src_lines)
        if ref_src_lines and not line_valid and len(srcline_wlist) == 1:
            if new_len < max_central_width:
                line_valid = True
        if not line_valid:
            break

        if insert_left:
            central_line.append_left(wlst_left.pop(-1), len_left[-1] + delimiter_len, delimiter)
            sum_left -= len_left.pop(-1)
            central_line.pos_x = new_x_l
        else:
            central_line.append_right(wlst_right.pop(0), len_right[0] + delimiter_len, delimiter)
            sum_right -= len_right.pop(0)
            central_line.pos_x = new_x_r

    # indices into srcline_wlist for the first line below / above the centre
    line_right_no = line_left_no = 0
    if ref_src_lines:
        nl = len(srcline_wlist)
        if nl % 2 == 0:
            line_right_no = nl // 2
            line_left_no = nl // 2 - 1
        else:
            line_right_no = nl // 2 + 1
            line_left_no = nl // 2 - 1

    if not start_from_top:
        central_line.strip_spacing()
        lines = [central_line]
    else:
        # discard the central line; every word goes through the bottom-half loop.
        # words / wl_list are aliased here, so the caller's lists are consumed.
        lines = []
        sum_right = sum(wl_list)
        sum_left = 0
        wlst_right = words
        len_right = wl_list
        line_right_no = 0

    # layout bottom half
    if sum_right > 0:
        w, wl = wlst_right.pop(0), len_right.pop(0)
        pos_x = centroid_x - wl // 2
        if start_from_top:
            pos_y = centroid_y - int(blk.bounding_rect()[3] / 2)
        else:
            pos_y = centroid_y + line_height // 2
        pos_y = max(0, min(pos_y, mask.shape[0] - 1))
        top_mean = mask[pos_y, :].mean()
        x_mean = mask.mean(axis=1)
        base_mean = x_mean.max() / 2
        # if the starting row is mostly outside the mask, shift down to the first
        # row whose mean exceeds half the maximum (by at most one line height)
        if top_mean < base_mean:
            available_y = np.where(
                x_mean[pos_y:] > base_mean
            )[0]
            if len(available_y) > 0:
                adjust_y = min(available_y[0], line_height)
                pos_y = pos_y + adjust_y
        line_bottom = pos_y + line_height
        line = Line(w, pos_x, pos_y, wl, spacing)
        lines.append(line)
        sum_right -= wl
        while sum_right > 0:
            w, wl = wlst_right.pop(0), len_right.pop(0)
            sum_right -= wl
            new_len = line.length + wl + delimiter_len
            new_x = centroid_x - new_len // 2 - line_height // 2
            right_x = new_x + new_len + line_height // 2
            if new_x < 0 or right_x >= bw:
                line_valid = False
            elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
                mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
                line_valid = False
                # the mask rejects the line, but the last source line may still take it
                if ref_src_lines and (len(wl_list) == 1 or line_right_no + 1 >= len(srcline_wlist)) and \
                    line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines):
                    line_valid = True
            else:
                line_valid = True
            if line_valid:
                line.append_right(w, wl+delimiter_len, delimiter)
                line.pos_x = new_x
                line_valid = line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines)
                if not line_valid:
                    # the line is now full: fetch the next word to open a new line,
                    # or finish when nothing is left
                    if sum_right > 0:
                        w, wl = wlst_right.pop(0), len_right.pop(0)
                        sum_right -= wl
                    else:
                        line.strip_spacing()
                        break

            if not line_valid:
                # start a new line below with the pending word
                pos_x = centroid_x - wl // 2
                pos_y = line_bottom
                line_bottom += line_height
                line.strip_spacing()
                line = Line(w, pos_x, pos_y, wl, spacing)
                lines.append(line)
                line_right_no += 1

    # layout top half
    if sum_left > 0:
        w, wl = wlst_left.pop(-1), len_left.pop(-1)
        pos_x = centroid_x - wl // 2
        pos_y = centroid_y - line_height // 2 - line_height
        pos_y = max(0, min(pos_y, mask.shape[0] - 1))
        line_bottom = pos_y + line_height
        line = Line(w, pos_x, pos_y, wl, spacing)
        lines.insert(0, line)
        sum_left -= wl
        while sum_left > 0:
            w, wl = wlst_left.pop(-1), len_left.pop(-1)
            sum_left -= wl
            new_len = line.length + wl + delimiter_len
            new_x = centroid_x - new_len // 2 - line_height // 2
            right_x = new_x + new_len + line_height // 2
            if new_x <= 0 or right_x >= bw:
                line_valid = False
            elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
                mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
                line_valid = False
                # the mask rejects the line, but the first source line may still take it
                if ref_src_lines and line_left_no - 1 < 0 and \
                    line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines):
                    line_valid = True
            else:
                line_valid = True
            if line_valid:
                line.append_left(w, wl+delimiter_len, delimiter)
                line.pos_x = new_x
                line_valid = line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines)
                if not line_valid:
                    if sum_left > 0:
                        w, wl = wlst_left.pop(-1), len_left.pop(-1)
                        sum_left -= wl
                    else:
                        line.strip_spacing()
                        break

            if not line_valid :
                # start a new line above with the pending word
                pos_x = centroid_x - wl // 2
                pos_y -= line_height
                line_bottom = pos_y + line_height
                line.strip_spacing()
                line = Line(w, pos_x, pos_y, wl, spacing)
                lines.insert(0, line)
                line_left_no -= 1

    return lines, (adjust_x, adjust_y)
298
+
299
def layout_lines_alignside(
    blk: TextBlock,
    mask: np.ndarray,
    words: List[str],
    origin: List[int],
    wl_list: List[int],
    delimiter_len: int,
    line_height: int,
    spacing: int = 0,
    delimiter: str = ' ',
    word_break: bool = False,
    max_width: int = np.inf,
    ref_src_lines = False,
    srcline_wlist=None,
) -> Tuple[List[Line], Tuple[int, int]]:
    """Break ``words`` into side-aligned lines starting at ``origin``.

    Words are appended to the current line while the probed mask column at the
    line's far edge is bright enough and ``line_is_valid`` accepts the new
    length; otherwise a new line is opened one ``line_height`` lower.

    Args:
        blk: text block; ``fontformat.alignment``, ``bounding_rect()`` and
            ``line_spacing`` are read.
        mask: 2-D array; an edge is accepted where the column mean over the
            line's rows exceeds 240.
        words: words to place; consumed in place with ``pop(0)``.
        origin: ``(x, y)`` of the first line.
        wl_list: pixel length of each word; consumed in place.
        delimiter_len: pixel length of ``delimiter``.
        line_height: vertical advance per line.
        spacing: unused in this function.
        delimiter: string inserted between words.
        word_break: unused in this function.
        max_width: width limit handed to ``line_is_valid``.
        ref_src_lines: also constrain lines by ``srcline_wlist``.
        srcline_wlist: per-source-line width ratios; required when
            ``ref_src_lines`` is set.

    Returns:
        ``(lines, (0, 0))``; ``lines`` is empty when ``words`` is empty.
    """

    align_right = blk.fontformat.alignment == TextAlignment.Right

    ox, oy = origin
    bh, bw = mask.shape[:2]
    num_words = len(words)
    blk_rect = blk.bounding_rect()
    blk_width = blk_rect[2]
    lines = []
    words_length = sum(wl_list)

    # rows trimmed from the bottom of each line when sampling the mask
    lh_pad = 0
    if blk.line_spacing > 1:
        lh_pad = int(np.ceil(line_height - line_height / blk.line_spacing))

    if num_words > 0:
        sum_right = np.array(wl_list).sum()
        w, wl = words.pop(0), wl_list.pop(0)
        line = Line(w, ox, oy, wl)
        lines.append(line)
        sum_right -= wl
        line_bottom = oy + line_height
        pos_y = oy
        line_id = 0
        # the loop ends once the remaining word lengths sum to zero
        while sum_right > 0:
            w, wl = words.pop(0), wl_list.pop(0)
            sum_right -= wl
            new_len = line.length + wl + delimiter_len
            # x of the column probed in the mask: the line's growing edge plus a
            # half-line-height margin
            if align_right:
                new_x = ox + blk_width - new_len - line_height // 2
            else:
                new_x = ox + new_len + line_height // 2
            line_valid = False
            if new_x < bw and new_x > 0:
                if mask[np.clip(pos_y, 0, bh - 1): np.clip(line_bottom - lh_pad, 0, bh), new_x].mean() > 240:
                    line_valid = True
                else:
                    # NOTE(review): indentation was lost in this copy of the file; this
                    # else is assumed to belong to the mask test above -- confirm.
                    if ref_src_lines and line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
                        line_valid = True
            if line_valid:
                line_valid = line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines)
            if line_valid:
                line.append_right(w, wl+delimiter_len, delimiter)
            else:
                # open a new line below with the current word
                pos_y = line_bottom
                line_bottom += line_height
                line = Line(w, ox, pos_y, wl)
                line_id += 1
                lines.append(line)
    return lines, (0, 0)
364
+
365
+
366
+
367
def layout_text(
    blk: TextBlock,
    mask: np.ndarray,
    mask_xyxy: List,
    centroid: List,
    words: List[str],
    wl_list: List[int],
    delimiter: str,
    delimiter_len: int,
    line_height: int,
    spacing: int = 0,
    max_central_width=np.inf,
    src_is_cjk=False,
    tgt_is_cjk=False,
    ref_src_lines = False
) -> Tuple[str, List, bool, Tuple]:
    """Lay out ``words`` inside ``mask`` and return the text with its canvas box.

    Dispatches to ``layout_lines_aligncenter`` for centre alignment and to
    ``layout_lines_alignside`` otherwise. If ``blk.angle`` is non-zero the mask
    is rotated and the centroid is mapped into the rotated mask first.

    Args:
        blk: text block; ``angle``, ``alignment`` and, with ``ref_src_lines``,
            ``normalizd_width_list`` are read.
        mask: 2-D array handed to the line layout functions.
        mask_xyxy: box of the mask; only its first two entries are read, as the
            offset added to the centroid.
        centroid: ``(x, y)`` anchor in mask coordinates.
        words, wl_list: words and their pixel lengths (may be consumed by the
            layout functions).
        delimiter, delimiter_len: word separator and its pixel length.
        line_height: vertical advance per line.
        spacing: side spacing forwarded to the layout functions.
        max_central_width: width limit; ``np.inf`` is replaced by the mask width.
        src_is_cjk, tgt_is_cjk: unused in this function.
        ref_src_lines: use the source line widths as a layout reference.

    Returns:
        ``(text, [abs_x, abs_y, canvas_w, canvas_h], start_from_top, adjust_xy)``
        where ``text`` joins the lines with newlines.
        NOTE(review): an empty result from the layout function would raise
        below (``min`` of an empty array / ``lines[0]``) -- confirm callers
        never pass an empty word list.
    """

    angle = blk.angle
    alignment = blk.alignment

    start_from_top = False
    srcline_wlist = None

    if ref_src_lines:
        srcline_wlist, srcline_width = blk.normalizd_width_list(normalize=False)
        # tgtline_width = sum(wl_list) + delimiter_len * max(len(wl_list) - 1, 0)
        # if tgtline_width < srcline_width:
        #     min_bbox = blk.min_rect(rotate_back=True)[0]
        #     x1, y1 = min_bbox[0]
        #     x2, y2 = min_bbox[2]
        #     w = x2 - x1
        #     max_central_width = min(max_central_width, w)
        #     pass

        # decide whether centred text should be filled from the top: always for two
        # source lines, otherwise when the first line is clearly wider than the last
        # and nearly the widest
        if alignment == TextAlignment.Center and \
            len(srcline_wlist) > 1:
            if len(srcline_wlist) == 2:
                start_from_top = True
            else:
                nw = len(srcline_wlist)
                # nl = min(nw // 2, 2)
                nl = 1
                sum_top = sum(srcline_wlist[:nl])
                sum_btn = sum(srcline_wlist[-nl:])
                start_from_top = sum_top / sum_btn > 1.2 and srcline_wlist[0] / max(srcline_wlist) > 0.9

        # convert the source line widths into ratios of srcline_width
        srcline_wlist = np.array(srcline_wlist) / srcline_width
        srcline_wlist = srcline_wlist.tolist()
        # line_height = min((blk.detected_font_size), line_height)

    # if ref_src_lines:
    #     mask = np.ones_like(mask) * 255

    if max_central_width == np.inf:
        max_central_width = mask.shape[1]

    centroid_x, centroid_y = centroid
    center_x = mask_xyxy[0] + centroid_x
    center_y = mask_xyxy[1] + centroid_y
    shifted_x, shifted_y = 0, 0
    if abs(angle) > 0:

        # centroid relative to the centre of the unrotated mask
        old_h, old_w = mask.shape[:2]
        old_origin = (old_w // 2, old_h // 2)
        rel_cx, rel_cy = centroid[0] - old_origin[0], centroid[1] - old_origin[1]

        mask = rotate_image(mask, angle)
        rad = np.deg2rad(angle)
        r_sin, r_cos = np.sin(rad), np.cos(rad)
        # rotate the relative centroid by the same angle
        new_rel_cy = -rel_cx * r_sin + rel_cy * r_cos
        new_rel_cx = rel_cy * r_sin + rel_cx * r_cos

        shifted_x, shifted_y = new_rel_cx - rel_cx, new_rel_cy - rel_cy

        # re-anchor the centroid at the centre of the rotated mask
        new_h, new_w = mask.shape[:2]
        new_origin = (new_w // 2, new_h // 2)
        new_cx, new_cy = new_origin[0] + new_rel_cx, new_origin[1] + new_rel_cy
        centroid = [int(new_cx), int(new_cy)]

    if alignment == TextAlignment.Center:
        lines, adjust_xy = layout_lines_aligncenter(blk, mask, words, centroid, wl_list, delimiter_len, line_height, spacing, delimiter,
                                                    max_central_width, ref_src_lines=ref_src_lines, srcline_wlist=srcline_wlist,
                                                    start_from_top=start_from_top)
    else:
        lines, adjust_xy = layout_lines_alignside(blk, mask, words, centroid, wl_list, delimiter_len, line_height, spacing, delimiter, False, max_central_width,
                                                  ref_src_lines=ref_src_lines, srcline_wlist=srcline_wlist)

    concated_text = []
    pos_x_lst, pos_right_lst = [], []
    for line in lines:
        pos_x_lst.append(line.pos_x)
        pos_right_lst.append(max(line.pos_x, 0) + line.length)
        concated_text.append(line.text)
    concated_text = '\n'.join(concated_text)

    # canvas = bounding box of all lines
    pos_x_lst = np.array(pos_x_lst)
    pos_right_lst = np.array(pos_right_lst)
    canvas_l, canvas_r = pos_x_lst.min(), pos_right_lst.max()
    canvas_t, canvas_b = lines[0].pos_y, lines[-1].pos_y + line_height

    canvas_h = int(canvas_b - canvas_t)
    canvas_w = int(canvas_r - canvas_l)

    # NOTE(review): compares against the literal 1 while the dispatch above uses
    # TextAlignment.Center -- presumably the same value; confirm.
    if alignment == 1:
        abs_x = int(round(center_x - canvas_w / 2))
        abs_y = int(round(center_y - canvas_h / 2))
    else:
        abs_x = shifted_x
        abs_y = shifted_y

    return concated_text, [abs_x, abs_y, canvas_w, canvas_h], start_from_top, adjust_xy
utils/text_processing.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple
2
+ import json
3
+ import os.path as osp
4
+ import os
5
+
6
# str.translate table: ASCII 0x21-0x7E -> the code point 0xFEE0 higher,
# and space (0x20) -> U+3000.
HALF2FULL = {i: i + 0xFEE0 for i in range(0x21, 0x7F)}
HALF2FULL[0x20] = 0x3000

# Inverse table; additionally maps U+3002 to '.' (0x2E).
FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
FULL2HALF[0x3000] = 0x20
FULL2HALF[0x3002] = 0x2E

# Language names grouped as CJK / Chinese.
LANGSET_CJK = {'简体中文', '繁體中文', '日本語'}
LANGSET_CH = {'简体中文', '繁體中文'}

# Punctuation after which seg_eng inserts a space when a letter or digit follows.
PUNSET_RIGHT_ENG = {'.', '?', '!', ':', ';', ')', '}', "\""}
# Opening punctuation.
# NOTE(review): several entries are duplicated here; possibly full-width
# variants were lost in this copy of the file -- confirm against upstream.
PUNCTUATION_L = {'「', '『', '【', '《', '〈', '〔', '[', '{', '(', '(', '[', '{', '“', '‘'}

# Tokens treated as punctuation by the pkuseg-based segmenter.
# NOTE(review): ' ' appears twice; one may originally have been a different
# whitespace character -- confirm.
PKUSEG_PUNCSET = {' ', '.', ' '}
PKUSEGPATH = r'data/pkusegscores.json'
# Module-level placeholders, initialised to None here.
PKUSEGSCORES = None
CHSEG = None
23
+
24
def full_len(s: str):
    """
    Return ``s`` with every character listed in ``HALF2FULL`` replaced by its
    full-width counterpart.
    https://stackoverflow.com/questions/2422177/python-how-can-i-replace-full-width-characters-with-half-width-characters
    """
    widened = s.translate(HALF2FULL)
    return widened
30
+
31
def half_len(s):
    '''
    Return ``s`` with every character listed in ``FULL2HALF`` replaced by its
    ASCII counterpart.
    '''
    narrowed = s.translate(FULL2HALF)
    return narrowed
36
+
37
def seg_to_chars(text: str) -> List[str]:
    """Split ``text`` into single characters, dropping newline characters."""
    return list(text.replace('\n', ''))
40
+
41
def seg_eng(text: str) -> List[str]:
    """Split space-delimited text into layout words, gluing very short tokens.

    Newlines become spaces, and a space is inserted after any character of
    ``PUNSET_RIGHT_ENG`` that is directly followed by a letter or digit. The
    text is then split on single spaces. Tokens shorter than three characters
    are joined (with a space) to the previous or the next token: one-character
    tokens always, two-character tokens only when that neighbour is at most
    four characters long. When both neighbours exist and are non-empty, the
    next one is chosen only if it is strictly shorter than the previous one.
    """
    text = text.replace(' ', ' ').replace(' .', '.').replace('\n', ' ')

    # make sure a space separates right-hand punctuation from the next word
    pieces = []
    last_pos = len(text) - 1
    for pos, ch in enumerate(text):
        pieces.append(ch)
        if ch in PUNSET_RIGHT_ENG and pos < last_pos:
            following = text[pos + 1]
            if following.isalpha() or following.isnumeric():
                pieces.append(' ')

    tokens = ''.join(pieces).split(' ')
    total = len(tokens)
    if total <= 1:
        return tokens

    merged = []
    consume_next = False
    for pos, token in enumerate(tokens):
        if consume_next:
            # this token was already glued onto the previous output word
            consume_next = False
            continue

        size = len(token)
        if size >= 3:
            merged.append(token)
            continue

        # -1 marks a missing neighbour
        next_size = len(tokens[pos + 1]) if pos < total - 1 else -1
        prev_size = len(merged[-1]) if pos > 0 else -1
        may_join_next = size == 1 or (size == 2 and next_size <= 4)
        may_join_prev = size == 1 or (size == 2 and prev_size <= 4)

        join_prev = join_next = False
        if next_size > 0 and prev_size > 0:
            if next_size < prev_size:
                join_next = may_join_next
            else:
                join_prev = may_join_prev
        elif next_size > 0:
            join_next = may_join_next
        elif prev_size > 0:
            join_prev = may_join_prev

        if join_prev:
            merged[-1] = merged[-1] + ' ' + token
        elif join_next:
            merged.append(token + ' ' + tokens[pos + 1])
            consume_next = True
        else:
            merged.append(token)
    return merged
97
+
98
def _seg_ch_pkg(text: str) -> List[str]:
    """Segment one space-free chunk with the global ``CHSEG`` tagger and merge segments.

    ``CHSEG.cut(text)`` is iterated as ``(word, tag)`` pairs. Each segment is
    either kept on its own or concatenated with the previous and/or next
    segment, driven by the tag-pair scores in ``PKUSEGSCORES`` and the
    punctuation sets ``PKUSEG_PUNCSET`` / ``PUNCTUATION_L``.

    Both ``CHSEG`` and ``PKUSEGSCORES`` must already be initialised (done by
    ``seg_ch_pkg``).

    Returns:
        The merged word list; ``[' ']`` for a single space and ``[]`` for an
        empty string or an empty segmentation.

    Raises:
        Re-raises any exception from the merge loop after printing the
        offending text.
    """

    if text == ' ':
        return [' ']
    elif text == '':
        return []

    segments = CHSEG.cut(text)
    num_segments = len(segments)
    if num_segments == 0:
        return []
    if num_segments == 1:
        return [segments[0][0]]

    words = []
    tags = []   # per output word: list of tags of the segments merged into it
    max_concat_len = 4
    skip_next = False
    try:
        for ii, (word, tag) in enumerate(segments):
            if skip_next:
                # this segment was already appended to the previous output word
                skip_next = False
                continue

            len_word, len_next, len_prev = len(word), -1, -1
            next_valid, prev_valid = False, False
            word_next, tag_next = '', ''
            word_prev, tag_prev = '', ''
            score_next, score_prev = 0, 0
            if ii < num_segments - 1:
                word_next, tag_next = segments[ii + 1]
                len_next = len(word_next)
                next_valid = True
                if tag_next != 'w' and not word_next in PKUSEG_PUNCSET:
                    score_next = PKUSEGSCORES[tag][tag_next]

            if ii > 0:
                # the "previous word" is the last *output* word (possibly already
                # merged), while its tag is taken from the previous raw segment
                word_prev, tag_prev = words[-1], segments[ii - 1][1]
                len_prev = len(word_prev)
                prev_valid = True
                if tag_prev != 'w' and not word_prev[-1] in PKUSEG_PUNCSET:
                    score_prev = PKUSEGSCORES[tag_prev][tag]

            append_prev, append_next = False, False

            if tag == 'w' or word in PKUSEG_PUNCSET:    # punctuation
                # opening punctuation sticks to the next segment, any other
                # single-character punctuation to the previous word
                if word in PUNCTUATION_L:
                    append_next = next_valid
                elif len_word <= 1:
                    append_prev = prev_valid
            else:
                next_valid = score_next > 0 and len_next < max_concat_len
                prev_valid = score_prev > 0 and len_prev < max_concat_len
                need_concat = len_word < max_concat_len
                # a score of exactly 1 forces concatenation
                append_prev = score_prev == 1
                append_next = score_next == 1
                if score_prev != 1 and score_next != 1 and need_concat:
                    append_prev = prev_valid
                    append_next = next_valid
                # NOTE(review): nesting level of this tie-break was reconstructed;
                # it is assumed to run for every non-punctuation segment -- confirm.
                if append_next and append_prev:
                    # both sides qualify: pick the shorter neighbour, and on a
                    # length tie the one with the higher score
                    if len_prev == len_next:
                        if score_prev >= score_next:
                            append_next = False
                        else:
                            append_prev = False
                    elif len_prev < len_next:
                        append_next = False
                    else:
                        append_prev = False

            if append_next and append_prev:
                words[-1] = word_prev + word + word_next
                tags[-1] = tags[-1] + [tag, tag_next]
                skip_next = True
            elif append_prev:
                words[-1] = words[-1] + word
                tags[-1].append(tag)
            elif append_next:
                words.append(word + word_next)
                tags.append([tag, tag_next])
                skip_next = True
            else:
                words.append(word)
                tags.append([tag])
    except Exception as e:
        print('exp at line: ', text)
        raise e
    return words
186
+
187
def seg_ch_pkg(text: str):
    """Segment Chinese text into layout words using pkuseg.

    On first use this lazily creates the global ``CHSEG`` tagger (``pkuseg``,
    falling back to ``spacy_pkuseg`` when ``pkuseg`` is not installed) and
    loads the global ``PKUSEGSCORES`` table from ``PKUSEGPATH``.

    The text is converted with ``full_len`` before segmentation; if that
    conversion changed anything, every resulting word is converted back with
    ``half_len``. Newlines are removed, the text is split on spaces, each
    non-empty chunk is segmented by ``_seg_ch_pkg``, and the first word of
    every chunk after the first gets a leading space.

    Args:
        text: The text to segment.

    Returns:
        The list of segmented words.

    Raises:
        ImportError: If neither ``pkuseg`` nor ``spacy_pkuseg`` is importable.
        OSError: If the score file at ``PKUSEGPATH`` cannot be opened.
    """
    global CHSEG, PKUSEGSCORES

    if CHSEG is None:
        try:
            import pkuseg
        except ImportError:
            # Only a missing package should trigger the fallback.
            import spacy_pkuseg as pkuseg
        CHSEG = pkuseg.pkuseg(postag=True)

    # pkuseg won't work with half-width punctuations
    # NOTE(review): both arguments of this replace() look identical here; the
    # first was presumably a full-width space -- confirm against the repository.
    fullen_text = full_len(text).replace(' ', ' ')
    cvt_back = fullen_text != text
    if cvt_back:
        text = fullen_text

    if PKUSEGSCORES is None:
        with open(PKUSEGPATH, 'r', encoding='utf8') as f:
            PKUSEGSCORES = json.load(f)

    # NOTE(review): same caveat as above for the whitespace literals.
    text_list = text.replace('\n', '').replace(' ', ' ').split(' ')
    result_list = []
    for ii, chunk in enumerate(text_list):
        if not chunk:
            continue
        words = _seg_ch_pkg(chunk)
        # _seg_ch_pkg may return an empty list; indexing it would raise.
        if not words:
            continue
        if ii > 0:
            # Restore the space that separated this chunk from the previous one.
            words[0] = ' ' + words[0]
        result_list.extend(words)

    if cvt_back:
        # The input was widened only for pkuseg; hand back half-width text.
        result_list = [half_len(word) for word in result_list]
    return result_list
224
+
225
def seg_text(text: str, lang: str) -> Tuple[List, str]:
    """Segment ``text`` according to ``lang`` and return ``(words, delimiter)``.

    Languages in ``LANGSET_CH`` go through ``seg_ch_pkg``, the remaining
    ``LANGSET_CJK`` languages are split per character, and everything else
    is segmented by ``seg_eng``. The delimiter is ``' '`` only for the
    ``seg_eng`` branch and ``''`` otherwise.
    """
    if lang in LANGSET_CH:
        return seg_ch_pkg(text), ''
    if lang in LANGSET_CJK:
        return seg_to_chars(text), ''
    return seg_eng(text), ' '
235
+
236
def is_cjk(lang: str) -> bool:
    """Return True when ``lang`` is one of the names in ``LANGSET_CJK``."""
    in_cjk_set = lang in LANGSET_CJK
    return in_cjk_set
utils/textblock.py ADDED
@@ -0,0 +1,908 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Tuple, Callable
2
+ import numpy as np
3
+ from shapely.geometry import Polygon
4
+ import math
5
+ import copy
6
+ import cv2
7
+ import re
8
+
9
+ from .imgproc_utils import union_area, xywh2xyxypoly, rotate_polygons, color_difference
10
+ from .structures import Union, List, Dict, field, nested_dataclass
11
+ from .split_text_region import split_textblock as split_text_region
12
+ from .fontformat import FontFormat, LineSpacingType, TextAlignment, fix_fontweight_qt
13
+ from .textblock_mask import canny_flood
14
+ from .textlines_merge import sort_pnts, Quadrilateral, merge_bboxes_text_region
15
+
16
+
17
# Language labels; `group_output` indexes this list with the detector class id.
LANG_LIST = ['eng', 'ja', 'unknown']
# Inverse of LANG_LIST: language label -> class index.
LANGCLS2IDX = {'eng': 0, 'ja': 1, 'unknown': 2}

# https://ayaka.shn.hk/hanregex/
# https://medium.com/the-artificial-impostor/detecting-chinese-characters-in-unicode-strings-4ac839ba313a
# Matches one character in U+AC00-U+D7A3, U+3040-U+30FF or U+4E00-U+9FFF;
# `TextBlock.get_text` uses it to avoid inserting spaces next to such characters.
CJKPATTERN = re.compile(r'[\uac00-\ud7a3\u3040-\u30ff\u4e00-\u9FFF]')
23
+
24
+
25
@nested_dataclass
class TextBlock:
    """A detected text region: its bounding box, text-line polygons, text and font format.

    Font-related attributes (size, weight, colours, alignment, ...) are stored
    on ``self.fontformat`` and exposed here through pass-through properties.
    """
    # bounding box as [x1, y1, x2, y2]; cast to int in __post_init__
    xyxy: List = field(default_factory = lambda: [0, 0, 0, 0])
    # text-line polygons, each made of four (x, y) points
    lines: List = field(default_factory = lambda: [])
    language: str = 'unknown'
    # font_size: float = -1.
    # per-line distance used by sort_lines(); filled in by examine_textblk
    distance: np.ndarray = None
    # rotation in degrees (converted with np.deg2rad in get_transformed_region)
    angle: int = 0
    # primary direction vector and its norm; filled in by examine_textblk
    vec: List = None
    norm: float = -1
    # set to True by try_merge_textline once this block was merged into another
    merged: bool = False
    # source text; get_text() accepts either a list of strings or a plain string
    text: List = field(default_factory = lambda : [])
    translation: str = ""
    rich_text: str = ""
    # optional cached [x, y, w, h]; see bounding_rect()
    _bounding_rect: List = None
    # orientation of the source text; defaults to self.vertical in __post_init__
    src_is_vertical: bool = None
    _detected_font_size: float = -1
    # name of the detector that produced the block (e.g. 'ctd'); None if unknown
    det_model: str = None

    region_mask: np.ndarray = None
    region_inpaint_dict: Dict = None

    fontformat: FontFormat = field(default_factory=lambda: FontFormat())

    # legacy keys migrated into fontformat by __post_init__, then deleted
    deprecated_attributes: dict = field(default_factory = lambda: dict())

    # ------------------------------------------------------------------
    # The properties below are plain pass-throughs to self.fontformat.
    # Note the renames: fg_colors <-> fontformat.frgb, bg_colors <-> fontformat.srgb.
    # ------------------------------------------------------------------
    @property
    def vertical(self):
        return self.fontformat.vertical

    @vertical.setter
    def vertical(self, value: bool):
        self.fontformat.vertical = value

    @property
    def font_size(self):
        return self.fontformat.font_size

    @font_size.setter
    def font_size(self, value: float):
        self.fontformat.font_size = value

    @property
    def line_spacing(self):
        return self.fontformat.line_spacing

    @line_spacing.setter
    def line_spacing(self, value: float):
        self.fontformat.line_spacing = value

    @property
    def letter_spacing(self):
        return self.fontformat.letter_spacing

    @letter_spacing.setter
    def letter_spacing(self, value: float):
        self.fontformat.letter_spacing = value

    @property
    def font_family(self):
        return self.fontformat.font_family

    @font_family.setter
    def font_family(self, value: str):
        self.fontformat.font_family = value

    @property
    def font_weight(self):
        return self.fontformat.font_weight

    @font_weight.setter
    def font_weight(self, value: int):
        self.fontformat.font_weight = value

    @property
    def bold(self):
        return self.fontformat.bold

    @bold.setter
    def bold(self, value: bool):
        self.fontformat.bold = value

    @property
    def italic(self):
        return self.fontformat.italic

    @italic.setter
    def italic(self, value: bool):
        self.fontformat.italic = value

    @property
    def underline(self):
        return self.fontformat.underline

    @underline.setter
    def underline(self, value: bool):
        self.fontformat.underline = value

    @property
    def stroke_width(self):
        return self.fontformat.stroke_width

    @stroke_width.setter
    def stroke_width(self, value: float):
        self.fontformat.stroke_width = value

    @property
    def opacity(self):
        return self.fontformat.opacity

    @opacity.setter
    def opacity(self, value: float):
        self.fontformat.opacity = value

    @property
    def shadow_radius(self):
        return self.fontformat.shadow_radius

    @shadow_radius.setter
    def shadow_radius(self, value: float):
        self.fontformat.shadow_radius = value

    @property
    def shadow_strength(self):
        return self.fontformat.shadow_strength

    @shadow_strength.setter
    def shadow_strength(self, value: float):
        self.fontformat.shadow_strength = value

    @property
    def shadow_color(self):
        return self.fontformat.shadow_color

    @shadow_color.setter
    def shadow_color(self, value: float):
        self.fontformat.shadow_color = value

    @property
    def shadow_offset(self):
        return self.fontformat.shadow_offset

    @shadow_offset.setter
    def shadow_offset(self, value: float):
        self.fontformat.shadow_offset = value

    @property
    def fg_colors(self):
        return self.fontformat.frgb

    @fg_colors.setter
    def fg_colors(self, value: Union[np.ndarray, List]):
        self.fontformat.frgb = value

    @property
    def bg_colors(self):
        return self.fontformat.srgb

    @bg_colors.setter
    def bg_colors(self, value: np.ndarray):
        self.fontformat.srgb = value

    @property
    def alignment(self):
        return self.fontformat.alignment

    @alignment.setter
    def alignment(self, value: int):
        self.fontformat.alignment = value

    def __post_init__(self):
        """Normalise field types and migrate ``deprecated_attributes`` into ``fontformat``."""
        if self.xyxy is not None:
            self.xyxy = [int(num) for num in self.xyxy]
        if self.distance is not None:
            self.distance = np.array(self.distance, np.float32)
        if self.vec is not None:
            self.vec = np.array(self.vec, np.float32)
        if self.src_is_vertical is None:
            self.src_is_vertical = self.vertical

        if self.rich_text:
            self.rich_text = fix_fontweight_qt(self.rich_text)

        da = self.deprecated_attributes
        if len(da) > 0:
            if 'accumulate_color' in da:
                # legacy format stored colours as separate r/g/b channels
                self.fg_colors = np.array([da['fg_r'], da['fg_g'], da['fg_b']], dtype=np.float32)
                self.bg_colors = np.array([da['bg_r'], da['bg_g'], da['bg_b']], dtype=np.float32)
                nlines = len(self)
                if da['accumulate_color'] and len(self) > 0:
                    # accumulated colours are sums over lines; turn them into means
                    self.fg_colors /= nlines
                    self.bg_colors /= nlines

            # legacy block key -> fontformat attribute name (None: same name)
            deprecated_blk_fmt_keys = {'vertical': None, 'line_spacing': None, 'letter_spacing': None, 'bold': None, 'underline': None, 'italic': None,
                'opacity': None, 'shadow_radius': None, 'shadow_strength': None, 'shadow_color': None, 'shadow_offset': None,
                'font_size': 'size', 'font_family': None, '_alignment': 'alignment', 'default_stroke_width': 'stroke_width', 'font_weight': None,
                'fg_colors': 'frgb', 'bg_colors': 'srgb'
            }
            for src_k, v in da.items():
                if src_k in deprecated_blk_fmt_keys:
                    if deprecated_blk_fmt_keys[src_k] is None:
                        tgt_k = src_k
                    else:
                        tgt_k = deprecated_blk_fmt_keys[src_k]
                    setattr(self.fontformat, tgt_k, v)
            # NOTE(review): nesting of this line and of the `del` below was
            # reconstructed from a flattened listing -- confirm.
            self.font_weight = fix_fontweight_qt(self.font_weight)

        del self.deprecated_attributes

    @property
    def detected_font_size(self):
        """The detected font size if positive, otherwise the configured ``font_size``."""
        if self._detected_font_size > 0:
            return self._detected_font_size
        return self.font_size

    def adjust_bbox(self, with_bbox=False, x_range=None, y_range=None):
        """Recompute ``xyxy`` from the line polygons.

        With ``with_bbox=True`` the current box is only ever enlarged;
        otherwise it is replaced by the tight box around the lines.
        ``x_range`` / ``y_range`` are optional ``(min, max)`` clipping limits.
        """
        lines = self.lines_array().astype(np.int32)
        if with_bbox:
            self.xyxy[0] = min(lines[..., 0].min(), self.xyxy[0])
            self.xyxy[1] = min(lines[..., 1].min(), self.xyxy[1])
            self.xyxy[2] = max(lines[..., 0].max(), self.xyxy[2])
            self.xyxy[3] = max(lines[..., 1].max(), self.xyxy[3])
        else:
            self.xyxy[0] = lines[..., 0].min()
            self.xyxy[1] = lines[..., 1].min()
            self.xyxy[2] = lines[..., 0].max()
            self.xyxy[3] = lines[..., 1].max()

        if x_range is not None:
            self.xyxy[0] = np.clip(self.xyxy[0], x_range[0], x_range[1])
            self.xyxy[2] = np.clip(self.xyxy[2], x_range[0], x_range[1])
        if y_range is not None:
            self.xyxy[1] = np.clip(self.xyxy[1], y_range[0], y_range[1])
            self.xyxy[3] = np.clip(self.xyxy[3], y_range[0], y_range[1])

    def sort_lines(self):
        """Reorder ``lines`` and ``distance`` by ascending ``distance`` (no-op if it is None)."""
        if self.distance is not None:
            idx = np.argsort(self.distance)
            self.distance = self.distance[idx]
            lines = np.array(self.lines, dtype=np.int32)
            self.lines = lines[idx].tolist()

    def lines_array(self, dtype=np.float64):
        """Return ``lines`` as a NumPy array of the given dtype."""
        return np.array(self.lines, dtype=dtype)

    def set_lines_by_xywh(self, xywh: np.ndarray, angle=0, x_range=None, y_range=None, adjust_bbox=False):
        """Replace ``lines`` with a single polygon built from an ``[x, y, w, h]`` rectangle.

        A non-zero ``angle`` rotates the polygon around the rectangle centre via
        ``rotate_polygons``; ``x_range`` / ``y_range`` clip the coordinates and
        ``adjust_bbox=True`` refreshes ``xyxy`` afterwards.
        """
        if isinstance(xywh, List):
            xywh = np.array(xywh)
        lines = xywh2xyxypoly(np.array([xywh]))
        if angle != 0:
            cx, cy = xywh[0], xywh[1]
            cx += xywh[2] / 2.
            cy += xywh[3] / 2.
            lines = rotate_polygons([cx, cy], lines, angle)

        lines = lines.reshape(-1, 4, 2)
        if x_range is not None:
            lines[..., 0] = np.clip(lines[..., 0], x_range[0], x_range[1])
        if y_range is not None:
            lines[..., 1] = np.clip(lines[..., 1], y_range[0], y_range[1])
        self.lines = lines.tolist()

        if adjust_bbox:
            self.adjust_bbox()

    def aspect_ratio(self) -> float:
        """Ratio of the two edge-midpoint spans of ``min_rect()`` (vertical over horizontal)."""
        min_rect = self.min_rect()
        # midpoints of the four edges of the rectangle
        middle_pnts = (min_rect[:, [1, 2, 3, 0]] + min_rect) / 2
        norm_v = np.linalg.norm(middle_pnts[:, 2] - middle_pnts[:, 0])
        norm_h = np.linalg.norm(middle_pnts[:, 1] - middle_pnts[:, 3])
        return norm_v / norm_h

    def center(self) -> np.ndarray:
        """Centre point of ``xyxy`` as a length-2 array."""
        xyxy = np.array(self.xyxy)
        return (xyxy[:2] + xyxy[2:]) / 2

    def unrotated_polygons(self, ids=None) -> Tuple[bool, np.ndarray, np.ndarray]:
        """Return ``(angled, center, polygons)`` for the selected lines.

        ``polygons`` holds the lines flattened to eight values each; when
        ``angle`` is non-zero they are passed through ``rotate_polygons`` with
        ``self.angle`` around ``center``. ``ids`` optionally selects a subset.
        """
        angled = self.angle != 0
        center = self.center()
        polygons = self.lines_array().reshape(-1, 8)
        if ids is not None:
            polygons = polygons[ids]
        if angled:
            polygons = rotate_polygons(center, polygons, self.angle)
        return angled, center, polygons

    def min_rect(self, rotate_back=True, ids=None) -> np.ndarray:
        """Axis-aligned box around the unrotated polygons, as an int64 array of shape (1, 4, 2).

        With ``rotate_back=True`` and a non-zero angle the box is rotated by
        ``-self.angle`` around the block centre before being returned.
        """
        angled, center, polygons = self.unrotated_polygons(ids=ids)
        min_x = polygons[:, ::2].min()
        min_y = polygons[:, 1::2].min()
        max_x = polygons[:, ::2].max()
        max_y = polygons[:, 1::2].max()
        min_bbox = np.array([[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]])
        if angled and rotate_back:
            min_bbox = rotate_polygons(center, min_bbox, -self.angle)
        return min_bbox.reshape(-1, 4, 2).astype(np.int64)

    def normalizd_width_list(self, normalize=True):
        """Return ``(width_list, sum_width)`` for the unrotated lines.

        Each width is the mean x-distance between points 1/2 and points 0/3 of
        a polygon; with ``normalize=True`` widths are divided by their sum.
        """
        angled, center, polygons = self.unrotated_polygons()
        width_list = []
        for polygon in polygons:
            width_list.append((polygon[[2, 4]] - polygon[[0, 6]]).mean())
        sum_width = sum(width_list)
        if normalize:
            width_list = np.array(width_list)
            width_list = width_list / sum_width
            width_list = width_list.tolist()
        return width_list, sum_width

    # equivalent to qt's boundingRect, ignore angle
    def bounding_rect(self) -> List[int]:
        """Return ``[x, y, w, h]``: the cached ``_bounding_rect`` if set, else computed from the lines."""
        if self._bounding_rect is None:
        # if True:
            min_bbox = self.min_rect(rotate_back=False)[0]
            x, y = min_bbox[0]
            w, h = min_bbox[2] - min_bbox[0]
            # the computed value is returned but not stored in _bounding_rect
            return [int(x), int(y), int(w), int(h)]
        return self._bounding_rect

    def __getattribute__(self, name: str):
        # 'pts' is a virtual attribute aliasing lines_array(); everything else
        # falls through to the normal lookup
        if name == 'pts':
            return self.lines_array()
        # else:
        return object.__getattribute__(self, name)

    def __len__(self):
        """Number of text lines."""
        return len(self.lines)

    def __getitem__(self, idx):
        """Return the line polygon at ``idx``."""
        return self.lines[idx]

    def to_dict(self, deep_copy=False):
        """Return the instance attribute dict (``vars(self)``), deep-copied on request.

        Without ``deep_copy`` the returned mapping is the live ``__dict__``.
        """
        blk_dict = vars(self)
        if deep_copy:
            blk_dict = copy.deepcopy(blk_dict)
        return blk_dict

    def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
        """Crop line ``idx`` out of ``img`` and warp it to an upright strip of height ``textheight``.

        Vertical source text is warped and then rotated 90 degrees
        counter-clockwise. If ``textheight`` is None it is derived from the
        line itself; ``maxwidth`` optionally caps the output width. A black
        ``textheight`` x ``textheight`` image is returned when the polygon is
        degenerate or no homography can be found.
        """
        im_h, im_w = img.shape[:2]

        line = np.round(np.array(self.lines[idx])).astype(np.int64)

        if not self.src_is_vertical and self.det_model == 'ctd':
            # ctd detected horizontal bbox is smaller than GT
            expand_size = max(int(self._detected_font_size * 0.1), 3)
            rad = np.deg2rad(self.angle)
            shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
            shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
            line = line + shifted_vec
            line[..., 0] = np.clip(line[..., 0], 0, im_w)
            line[..., 1] = np.clip(line[..., 1], 0, im_h)
            # broadcasting above added a leading axis; drop it again
            line = np.round(line[0]).astype(np.int64)

        x1, y1, x2, y2 = line[:, 0].min(), line[:, 1].min(), line[:, 0].max(), line[:, 1].max()

        x1 = np.clip(x1, 0, im_w)
        y1 = np.clip(y1, 0, im_h)
        x2 = np.clip(x2, 0, im_w)
        y2 = np.clip(y2, 0, im_h)
        img_croped = img[y1: y2, x1: x2]

        direction = 'v' if self.src_is_vertical else 'h'

        # polygon coordinates relative to the crop
        src_pts = line.copy()
        src_pts[:, 0] -= x1
        src_pts[:, 1] -= y1
        middle_pnt = (src_pts[[1, 2, 3, 0]] + src_pts) / 2
        vec_v = middle_pnt[2] - middle_pnt[0]   # vertical vectors of textlines
        vec_h = middle_pnt[1] - middle_pnt[3]   # horizontal vectors of textlines
        norm_v = np.linalg.norm(vec_v)
        norm_h = np.linalg.norm(vec_h)

        if textheight is None:
            if direction == 'h' :
                textheight = int(norm_v)
            else:
                textheight = int(norm_h)

        if norm_v <= 0 or norm_h <= 0:
            print('invalid textpolygon to target img')
            return np.zeros((textheight, textheight, 3), dtype=np.uint8)
        ratio = norm_v / norm_h

        if direction == 'h' :
            h = int(textheight)
            w = int(round(textheight / ratio))
            dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
            M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            if M is None:
                print('invalid textpolygon to target img')
                return np.zeros((textheight, textheight, 3), dtype=np.uint8)
            region = cv2.warpPerspective(img_croped, M, (w, h))
        elif direction == 'v' :
            w = int(textheight)
            h = int(round(textheight * ratio))
            dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
            M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
            if M is None:
                print('invalid textpolygon to target img')
                return np.zeros((textheight, textheight, 3), dtype=np.uint8)
            region = cv2.warpPerspective(img_croped, M, (w, h))
            region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)

        if maxwidth is not None:
            h, w = region.shape[: 2]
            if w > maxwidth:
                region = cv2.resize(region, (maxwidth, h))

        return region

    def get_text(self) -> str:
        """Return the block text as one stripped string.

        A string ``text`` is returned unchanged. For a list, items are
        concatenated and a space is inserted between two items only when the
        touching characters are both alphabetic and neither matches
        ``CJKPATTERN``.
        """
        if isinstance(self.text, str):
            return self.text
        text = ''
        for t in self.text:
            if text and t:
                if text[-1].isalpha() and t[0].isalpha() \
                    and CJKPATTERN.search(text[-1]) is None \
                    and CJKPATTERN.search(t[0]) is None:
                    text += ' '
            text += t

        return text.strip()

    def set_font_colors(self, fg_colors = None, bg_colors = None):
        """Overwrite foreground and/or background colours; ``None`` leaves a colour untouched."""
        if fg_colors is not None:
            self.fg_colors = fg_colors
        if bg_colors is not None:
            self.bg_colors = bg_colors

    def update_font_colors(self, fg_colors: np.ndarray, bg_colors: np.ndarray):
        """Add ``colors / number_of_lines`` to the stored colours (no-op for a block without lines)."""
        nlines = len(self)
        if nlines > 0:
            if not isinstance(fg_colors, np.ndarray):
                fg_colors = np.array(fg_colors, dtype=np.float32)
            if not isinstance(bg_colors, np.ndarray):
                bg_colors = np.array(bg_colors, dtype=np.float32)
            if not isinstance(self.fg_colors, np.ndarray):
                self.fg_colors = np.array(self.fg_colors, dtype=np.float32)
            if not isinstance(self.bg_colors, np.ndarray):
                self.bg_colors = np.array(self.bg_colors, dtype=np.float32)
            self.fg_colors += fg_colors / nlines
            self.bg_colors += bg_colors / nlines

    def get_font_colors(self, bgr=False):
        """Return ``(foreground, background)`` as int32 arrays, channel order reversed if ``bgr``."""

        frgb = np.array(self.fg_colors).astype(np.int32)
        brgb = np.array(self.bg_colors).astype(np.int32)

        if bgr:
            frgb = frgb[::-1]
            brgb = brgb[::-1]

        return frgb, brgb

    def xywh(self):
        """Return ``xyxy`` converted to ``[x, y, width, height]``."""
        # here `w` and `h` temporarily hold x2 and y2
        x, y, w, h = self.xyxy
        return [x, y, w-x, h-y]

    def recalulate_alignment(self):
        """Guess the text alignment from how much line edges scatter.

        Compares the standard deviation of the x coordinate of point 0, of
        point 1 and of their midpoint (the latter scaled by 0.7) across lines
        and picks the side with the smallest spread; ties fall to Center.
        """
        angled, center, polygons = self.unrotated_polygons()
        polygons = polygons.reshape(-1, 4, 2)

        left_std = np.std(polygons[:, 0, 0])
        right_std = np.std(polygons[:, 1, 0])
        center_std = np.std((polygons[:, 0, 0] + polygons[:, 1, 0]) / 2) * 0.7

        if left_std < right_std and left_std < center_std:
            self.alignment = TextAlignment.Left
        elif right_std < left_std and right_std < center_std:
            self.alignment = TextAlignment.Right
        else:
            self.alignment = TextAlignment.Center

    def recalulate_stroke_width(self, color_diff_tol = 15, stroke_width: float = 0.2):
        """Set ``stroke_width`` to 0 when fore- and background colours differ by less than
        ``color_diff_tol`` (per ``color_difference``), otherwise to ``stroke_width``."""
        if color_difference(*self.get_font_colors()) < color_diff_tol:
            self.stroke_width = 0.
        else:
            self.stroke_width = stroke_width

    def adjust_pos(self, dx: int, dy: int):
        """Shift ``xyxy`` (and the cached bounding rect, if any) by ``(dx, dy)``.

        The line polygons are not moved by this method.
        """
        self.xyxy[0] += dx
        self.xyxy[1] += dy
        self.xyxy[2] += dx
        self.xyxy[3] += dy
        if self._bounding_rect is not None:
            self._bounding_rect[0] += dx
            self._bounding_rect[1] += dy

    def line_coord_valid(self, rect):
        """Check whether the line polygons still agree with ``rect`` (``[x, y, w, h]``).

        Returns False when no detector is recorded, when the line box is
        degenerate, or when less than 60% of the line box area overlaps
        ``rect``. ``rect=None`` falls back to ``bounding_rect()``.
        """
        if self.det_model is None:
            return False
        if rect is None:
            rect = self.bounding_rect()

        min_bbox = self.min_rect(rotate_back=True)[0]
        x1, y1 = min_bbox[0]
        x2, y2 = min_bbox[2]
        w = x2 - x1
        h = y2 - y1
        if w < 1 or h < 1:
            return False
        # rect arrives as x, y, w, h; convert to corner coordinates
        rx1, ry1, rx2, ry2 = rect
        rx2 += rx1
        ry2 += ry1
        intersect = max(min(x2, rx2) - max(x1, rx1), 0) * max(min(y2, ry2) - max(y1, ry1), 0)
        if intersect == 0:
            return False
        if intersect / (w * h) < 0.6:
            return False
        return True
536
+
537
+
538
def sort_regions(regions: List[TextBlock], right_to_left=None) -> List[TextBlock]:
    """Order text blocks row by row for reading (adapted from manga image translator).

    Blocks are visited by ascending centre y and inserted into the result
    relative to the blocks already placed: a block whose centre y falls inside
    an already placed block's vertical span is ordered against it by centre x.
    When ``right_to_left`` is None it becomes True as soon as at least one
    block is vertical.
    """
    total = len(regions)
    if right_to_left is None and total > 0:
        vertical_count = sum(1 for candidate in regions if candidate.vertical)
        right_to_left = vertical_count / total > 0

    ordered = []
    for region in sorted(regions, key=lambda region: region.center()[1]):
        center_x, center_y = region.center()
        for pos, placed in enumerate(ordered):
            if center_y > placed.xyxy[3]:
                # entirely below this placed block: keep scanning
                continue
            if center_y < placed.xyxy[1]:
                ordered.insert(pos + 1, region)
                break

            # y center of region inside the placed block so sort by x instead
            placed_x = placed.center()[0]
            if right_to_left:
                if center_x > placed_x:
                    ordered.insert(pos, region)
                    break
            elif center_x < placed_x:
                ordered.insert(pos, region)
                break
        else:
            # no insertion point found among the placed blocks
            ordered.append(region)
    return ordered
569
+
570
+
571
def examine_textblk(blk: TextBlock, im_w: int, im_h: int, sort: bool = False) -> None:
    """Derive orientation-dependent geometry of ``blk`` from its line polygons.

    Updates ``blk`` in place: ``lines`` (cast to int32 lists), ``distance``
    (per-line perpendicular distance from the reading origin to the line's
    axis), ``angle`` (degrees, snapped to 0 below 3), ``font_size``, ``vec``
    and ``norm``. The reading origin is ``(im_w, 0)`` for vertical text and
    ``(0, 0)`` otherwise.

    Args:
        blk: Block to examine; must contain at least one line.
        im_w: Image width, used as the x origin for vertical text.
        im_h: Image height (currently unused).
        sort: If True, ``blk.sort_lines()`` is called at the end.
    """
    lines = blk.lines_array()
    middle_pnts = (lines[:, [1, 2, 3, 0]] + lines) / 2
    vec_v = middle_pnts[:, 2] - middle_pnts[:, 0]   # vertical vectors of textlines
    vec_h = middle_pnts[:, 1] - middle_pnts[:, 3]   # horizontal vectors of textlines
    center_pnts = (lines[:, 0] + lines[:, 2]) / 2
    v = np.sum(vec_v, axis=0)
    h = np.sum(vec_h, axis=0)
    norm_v, norm_h = np.linalg.norm(v), np.linalg.norm(h)
    vertical = blk.src_is_vertical
    # calculate distance between textlines and origin
    if vertical:
        primary_vec, primary_norm = v, norm_v
        # vertical manga text is read from right to left, so origin is (imw, 0)
        distance_vectors = center_pnts - np.array([[im_w, 0]], dtype=np.float64)
        font_size = int(round(norm_h / len(lines)))
    else:
        primary_vec, primary_norm = h, norm_h
        distance_vectors = center_pnts - np.array([[0, 0]], dtype=np.float64)
        font_size = int(round(norm_v / len(lines)))

    rotation_angle = int(math.atan2(primary_vec[1], primary_vec[0]) / math.pi * 180)   # rotation angle of textlines

    # Perpendicular distance |d| * |sin(theta)| written as |d x p| / |p|.
    # Unlike arccos followed by sin, this cannot produce NaN when a line centre
    # coincides with the origin or when rounding pushes the cosine past +/-1.
    if primary_norm > 0:
        cross = distance_vectors[:, 0] * primary_vec[1] - distance_vectors[:, 1] * primary_vec[0]
        distance = np.abs(cross) / primary_norm
    else:
        # Degenerate direction vector: fall back to the plain distance to the origin.
        distance = np.linalg.norm(distance_vectors, axis=1)

    blk.lines = lines.astype(np.int32).tolist()
    blk.distance = distance
    blk.angle = rotation_angle
    if vertical:
        blk.angle -= 90
    if abs(blk.angle) < 3:
        blk.angle = 0
    blk.font_size = font_size
    blk.vec = primary_vec
    blk.norm = primary_norm
    if sort:
        blk.sort_lines()
608
+
609
def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.7, distance_tol=2) -> bool:
    """Try to merge ``blk2`` into ``blk``; return True if the merge happened.

    The merge is rejected when ``blk2`` is already merged, and -- for last
    lines that do not intersect -- when the gap, font-size ratio or direction
    of the two blocks differ too much. On success ``blk`` receives the lines
    of ``blk2``, its ``vec``/``angle``/``norm``/``distance``/``font_size`` are
    updated and ``blk2.merged`` is set.

    ``distance_tol`` is not used by the current implementation.
    """
    if blk2.merged:
        return False
    fntsize_div = blk.font_size / blk2.font_size
    num_l1, num_l2 = len(blk), len(blk2)
    # line-count weighted mean font size of both blocks
    fntsz_avg = (blk.font_size * num_l1 + blk2.font_size * num_l2) / (num_l1 + num_l2)
    vec_prod = blk.vec @ blk2.vec
    vec_sum = blk.vec + blk2.vec
    cos_vec = vec_prod / blk.norm / blk2.norm
    # distance = blk2.distance[-1] - blk.distance[-1]
    # distance_p1 = np.linalg.norm(np.array(blk2.lines[-1][0]) - np.array(blk.lines[-1][0]))
    minrect1 = blk.min_rect(ids=[-1])[0]
    xyxy1 = [*minrect1[0], *minrect1[2]]
    minrect2 = blk2.min_rect(ids=[-1])[0]
    xyxy2 = [*minrect2[0], *minrect2[2]]
    # gap between the last lines along each axis (negative when they overlap)
    distance_x = max(xyxy1[0], xyxy2[0]) - min(xyxy1[2], xyxy2[2])
    distance_y = max(xyxy1[1], xyxy2[1]) - min(xyxy1[3], xyxy2[3])

    l1, l2 = Polygon(blk.lines[-1]), Polygon(blk2.lines[-1])
    # NOTE(review): the nesting below was reconstructed from a flattened
    # listing; all rejection tests are assumed to apply only to
    # non-intersecting lines -- confirm against the repository copy.
    if not l1.intersects(l2):
        if blk.vertical:
            if distance_y > 0:
                return False
        else:
            if distance_x > 0:
                return False
        if fntsize_div > fntsize_tol or 1 / fntsize_div > fntsize_tol:
            return False
        if abs(cos_vec) < 0.866:   # cos30
            return False
        # if distance > distance_tol * fntsz_avg:
        #     return False
        if blk.vertical and blk2.vertical and distance_x > fntsz_avg * 0.8:
            return False
        if not blk.vertical and distance_y > fntsz_avg * 0.5:
            return False
    # merge
    for line in blk2.lines:
        blk.lines.append(line)
    blk.vec = vec_sum
    blk.angle = int(round(np.rad2deg(math.atan2(vec_sum[1], vec_sum[0]))))
    if blk.vertical:
        blk.angle -= 90
    blk.norm = np.linalg.norm(vec_sum)
    # only the last distance of blk2 is carried over
    blk.distance = np.append(blk.distance, blk2.distance[-1])
    blk.font_size = fntsz_avg
    blk2.merged = True
    return True
657
+
658
def merge_textlines(blk_list: List[TextBlock]) -> List[TextBlock]:
    """Greedily merge blocks with ``try_merge_textline`` and return the surviving blocks.

    ``blk_list`` is sorted in place by each block's first ``distance`` entry.
    Every block that has not been merged yet absorbs whichever later blocks
    ``try_merge_textline`` accepts, and the bounding boxes of the survivors
    are recomputed from their lines. Lists with fewer than two blocks are
    returned unchanged.
    """
    if len(blk_list) < 2:
        return blk_list

    blk_list.sort(key=lambda blk: blk.distance[0])
    survivors = []
    for next_pos, anchor in enumerate(blk_list, start=1):
        if not anchor.merged:
            for candidate in blk_list[next_pos:]:
                try_merge_textline(anchor, candidate)
            survivors.append(anchor)

    for survivor in survivors:
        survivor.adjust_bbox(with_bbox=False)
    return survivors
672
+
673
def split_textblk(blk: TextBlock):
    """Split ``blk`` into sub-blocks wherever consecutive lines are too far apart.

    Returns ``(textblock_splitted, sub_blk_list)``; the flag is True when more
    than one sub-block was produced.

    NOTE(review): ``blk.lines`` is sorted in place here while ``blk.distance``
    keeps its previous order, so ``distance[jj]`` presumably relies on both
    orders agreeing -- confirm before re-enabling the caller.
    """
    font_size, distance, lines = blk.font_size, blk.distance, blk.lines
    l0 = np.array(blk.lines[0])
    # order lines by the distance of their first point to the first line's first point
    lines.sort(key=lambda line: np.linalg.norm(np.array(line[0]) - l0[0]))
    distance_tol = font_size * 2
    current_blk = copy.deepcopy(blk)
    current_blk.lines = [l0]
    sub_blk_list = [current_blk]
    textblock_splitted = False
    for jj, line in enumerate(lines[1:]):
        # lines[jj] is the predecessor of `line` because enumeration starts at lines[1]
        l1, l2 = Polygon(lines[jj]), Polygon(line)
        split = False
        if not l1.intersects(l2):
            line_disance = abs(distance[jj+1] - distance[jj])
            if line_disance > distance_tol:
                split = True
            elif blk.vertical and abs(blk.angle) < 15:
                if len(current_blk.lines) > 1 or line_disance > font_size:
                    # near-upright vertical text: split on a jump of the first point's y
                    split = abs(lines[jj][0][1] - line[0][1]) > font_size
        if split:
            current_blk = copy.deepcopy(current_blk)
            current_blk.lines = [line]
            sub_blk_list.append(current_blk)
        else:
            current_blk.lines.append(line)
    # NOTE(review): the bbox refresh is assumed to run only when a split
    # happened (nesting reconstructed) -- confirm.
    if len(sub_blk_list) > 1:
        textblock_splitted = True
        for current_blk in sub_blk_list:
            current_blk.adjust_bbox(with_bbox=False)
    return textblock_splitted, sub_blk_list
703
+
704
def group_output(blks, lines, im_w, im_h, mask=None, sort_blklist=True, canvas=None) -> List[TextBlock]:
    """Assign detected text lines to detected text blocks and return the final block list.

    Args:
        blks: three parallel sequences ``(bboxes, classes, confidences)``; each class
            indexes ``LANG_LIST``. The confidences are not used.
        lines: iterable of text-line point arrays; each one is normalised by ``sort_pnts``.
        im_w, im_h: image size, forwarded to ``examine_textblk``.
        mask: optional text mask indexed as ``mask[y1:y2, x1:x2]``. Lines/blocks whose
            mean mask value / 255 is below 0.1 are dropped.
        sort_blklist: when true the result is ordered with ``sort_regions``.
        canvas: unused, kept for interface compatibility.

    Returns:
        The list of ``TextBlock`` objects after filtering, merging of unassigned
        lines and removal of blocks that are >90% covered by an earlier block.
    """
    bbox_score_thresh = 0.4
    mask_score_thresh = 0.1

    def _has_enough_mask(x1, y1, x2, y2) -> bool:
        # Without a mask nothing is filtered. ``not (a < b)`` keeps the original
        # behaviour of letting NaN scores (empty slices) pass.
        if mask is None:
            return True
        mask_score = mask[y1: y2, x1: x2].mean() / 255
        return not (mask_score < mask_score_thresh)

    blk_list: List[TextBlock] = []
    scattered_lines = {'ver': [], 'hor': []}
    for bbox, cls, _conf in zip(*blks):
        # cls could give wrong result
        blk_list.append(TextBlock(bbox, language=LANG_LIST[cls]))

    # step1: filter & assign lines to textblocks
    for line in lines:
        line, is_vertical = sort_pnts(line)
        bx1, bx2 = line[:, 0].min(), line[:, 0].max()
        by1, by2 = line[:, 1].min(), line[:, 1].max()
        line_area = (by2 - by1) * (bx2 - bx1)
        if line_area <= 0:
            # Degenerate (zero width or height) line: it cannot be scored without
            # dividing by zero and carries no text area, so drop it.
            continue

        bbox_score, bbox_idx = -1, -1
        for jj, candidate in enumerate(blk_list):
            score = union_area(candidate.xyxy, [bx1, by1, bx2, by2]) / line_area
            if bbox_score < score:
                bbox_score, bbox_idx = score, jj

        if bbox_score > bbox_score_thresh:
            blk_list[bbox_idx].lines.append(line)
            blk_list[bbox_idx].adjust_bbox(with_bbox=True)
        else:
            # if no textblock was assigned, check whether there is "enough" textmask
            if not _has_enough_mask(bx1, by1, bx2, by2):
                continue
            blk = TextBlock([bx1, by1, bx2, by2], [line])
            blk.vertical = blk.src_is_vertical = is_vertical
            examine_textblk(blk, im_w, im_h, sort=False)
            scattered_lines['ver' if blk.vertical else 'hor'].append(blk)

    # step2: filter textblocks, sort & split textlines
    final_blk_list = []
    for blk in blk_list:
        if len(blk.lines) == 0:
            # A block without lines is kept only if the mask says there is text;
            # its own bbox then becomes its single line.
            bx1, by1, bx2, by2 = blk.xyxy
            if not _has_enough_mask(bx1, by1, bx2, by2):
                continue
            xywh = np.array([[bx1, by1, bx2 - bx1, by2 - by1]])
            blk.lines = xywh2xyxypoly(xywh).reshape(-1, 4, 2).tolist()
        else:
            blk.adjust_bbox(with_bbox=False)
        examine_textblk(blk, im_w, im_h, sort=True)

        # Splitting via split_textblk() is currently disabled, but multi-line
        # Japanese / vertical blocks are still flagged so that their bbox is not
        # re-fitted below (this mirrors the previous behaviour).
        textblock_splitted = bool(
            len(blk.lines) > 1 and (blk.language == 'ja' or blk.vertical)
        )
        sub_blk_list = [blk]
        # modify textblock to fit its textlines
        if not textblock_splitted:
            for sub_blk in sub_blk_list:
                sub_blk.adjust_bbox(with_bbox=True)
        final_blk_list += sub_blk_list

    # vertical blocks are merged together with the scattered vertical lines
    horizontal_blks = []
    for blk in final_blk_list:
        if blk.vertical:
            scattered_lines['ver'].append(blk)
        else:
            horizontal_blks.append(blk)
    final_blk_list = horizontal_blks

    # step3: merge scattered lines, sort textblocks by "grid"
    final_blk_list += merge_textlines(scattered_lines['hor'])
    final_blk_list += merge_textlines(scattered_lines['ver'])
    if sort_blklist:
        final_blk_list = sort_regions(final_blk_list, )
    for blk in final_blk_list:
        blk.distance = None

    # drop every block whose bbox is >90% covered by an already kept block
    if len(final_blk_list) > 1:
        kept_blks = [final_blk_list[0]]
        for blk in final_blk_list[1:]:
            ax1, ay1, ax2, ay2 = blk.xyxy
            aarea = (ax2 - ax1) * (ay2 - ay1) + 1e-6  # epsilon avoids division by zero
            keep_blk = True
            for eb in kept_blks:
                bx1, by1, bx2, by2 = eb.xyxy
                x1, y1 = max(ax1, bx1), max(ay1, by1)
                x2, y2 = min(ax2, bx2), min(ay2, by2)
                if y2 < y1 or x2 < x1:
                    continue
                inter_area = (y2 - y1) * (x2 - x1)
                if inter_area / aarea > 0.9:
                    keep_blk = False
                    break
            if keep_blk:
                kept_blks.append(blk)
        final_blk_list = kept_blks

    for blk in final_blk_list:
        if blk.language != 'ja' and not blk.vertical and len(blk.lines) == 0:
            continue
        # NOTE(review): the source this was recovered from had lost its indentation;
        # this assignment is assumed to apply to every block that is not skipped
        # above - confirm against the upstream file.
        blk._detected_font_size = blk.font_size

    return final_blk_list
820
+
821
def visualize_textblocks(canvas, blk_list: List[TextBlock]):
    """Draw a debug overlay of every text block onto ``canvas`` and return it.

    For each block the bbox, every line polygon (with its index), the block's
    ``min_rect()``, its angle (at the bbox centre) and its index are drawn.
    ``canvas`` is modified in place.
    """
    thickness = max(round(sum(canvas.shape) / 2 * 0.003), 2)  # line width
    index_thickness = max(thickness - 1, 1)
    for blk_no, blk in enumerate(blk_list):
        left, top, right, bottom = blk.xyxy
        cv2.rectangle(canvas, (left, top), (right, bottom), (127, 255, 127), thickness)
        for line_no, poly in enumerate(blk.lines_array(dtype=np.int32)):
            cv2.putText(canvas, str(line_no), poly[0], cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,127,0), 1)
            cv2.polylines(canvas, [poly], True, (0,127,255), 2)
        cv2.polylines(canvas, [blk.min_rect()], True, (127,127,0), 2)
        center = [int((left + right)/2), int((top + bottom)/2)]
        cv2.putText(canvas, str(blk.angle), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (127,127,255), 2)
        cv2.putText(canvas, str(blk_no), (left, top + thickness + 2), 0, thickness / 3, (255,127,127), index_thickness, cv2.LINE_AA)
    return canvas
835
+
836
def collect_textblock_regions(img: np.ndarray, textblk_lst: List[TextBlock], text_height=48, maxwidth=8100, split_textblk = False, seg_func: Callable = None):
    """Crop one image region per text line (or per split sub-line) of every block.

    Args:
        img: source image passed to ``TextBlock.get_transformed_region``.
        textblk_lst: blocks to crop; ``len(textblk)`` is the number of lines.
        text_height: target height of every returned region.
        maxwidth: upper bound for the width of every returned region.
        split_textblk: when true, single-line blocks are segmented with
            ``seg_func`` and split into rows with ``split_text_region``.
        seg_func: segmentation callable whose first return value is the text
            mask; defaults to ``canny_flood`` when not given.

    Returns:
        ``(regions, textblk_lst_indices)`` - the cropped regions and, for each
        region, the index of the block it came from.
    """
    regions = []
    textblk_lst_indices = []
    for blk_idx, textblk in enumerate(textblk_lst):
        num_lines = len(textblk)
        for ii in range(num_lines):
            if split_textblk and num_lines == 1:
                if seg_func is None:
                    # previously the caller's seg_func was always overwritten here
                    seg_func = canny_flood
                region = textblk.get_transformed_region(img, ii, None, maxwidth=None)
                mask = seg_func(region)[0]
                split_lines = split_text_region(mask)[0]
                for line in split_lines:
                    # a single detected row spans down to the bottom of the region
                    bottom = region.shape[0] if len(split_lines) == 1 else line[3]
                    r = region[line[1]: bottom]
                    h, w = r.shape[:2]
                    if h == 0 or w == 0:
                        # empty slice: nothing to recognise and resizing would fail
                        continue
                    tgt_h = text_height
                    tgt_w = max(1, min(maxwidth, int(text_height / h * w)))
                    if tgt_h != h or tgt_w != w:
                        r = cv2.resize(r, (tgt_w, tgt_h), interpolation=cv2.INTER_LINEAR)
                    regions.append(r)
                    textblk_lst_indices.append(blk_idx)
            else:
                textblk_lst_indices.append(blk_idx)
                region = textblk.get_transformed_region(img, ii, text_height, maxwidth=maxwidth)
                regions.append(region)

    return regions, textblk_lst_indices
866
+
867
+
868
def mit_merge_textlines(textlines: List[Quadrilateral], width: int, height: int, verbose: bool = False) -> List[TextBlock]:
    """Merge text lines into ``TextBlock`` regions using ``merge_bboxes_text_region``.

    Args:
        textlines: ``Quadrilateral`` objects or raw point arrays (wrapped into
            ``Quadrilateral(np.array(line), '', 1.)``).
        width, height: image size forwarded to ``merge_bboxes_text_region``.
        verbose: unused, kept for interface compatibility.

    Returns:
        One ``TextBlock`` per merged group, with font size set to the smallest
        line font size and angle set to the mean line angle (snapped to 0 when
        within 3 degrees).
    """
    # from https://github.com/zyddnys/manga-image-translator
    textlines = [
        line if isinstance(line, Quadrilateral) else Quadrilateral(np.array(line), '', 1.)
        for line in textlines
    ]

    text_regions: List[TextBlock] = []
    for txtlns, fg_color, bg_color in merge_bboxes_text_region(textlines, width, height):
        font_size = int(min(txtln.font_size for txtln in txtlns))
        angle = np.rad2deg(np.mean([txtln.angle for txtln in txtlns])) - 90
        if abs(angle) < 3:
            angle = 0
        lines = [txtln.pts for txtln in txtlns]
        texts = [txtln.text for txtln in txtlns]
        ffmt = FontFormat(font_size=font_size, frgb=fg_color, srgb=bg_color)

        nv = sum(1 for txtln in txtlns if txtln.direction == 'v')
        # At least one vertical line is required: with a single line
        # ``len(txtlns) // 2`` is 0, so the bare ``nv >= ...`` test used to flag
        # every lone horizontal line as vertical.
        is_vertical = nv > 0 and nv >= len(txtlns) // 2
        region = TextBlock(
            lines=lines, text=texts, angle=angle, fontformat=ffmt,
            _detected_font_size=font_size, src_is_vertical=is_vertical, vertical=is_vertical)
        region.adjust_bbox()
        if region.src_is_vertical:
            region.alignment = 1
        else:
            region.recalulate_alignment()
        text_regions.append(region)

    return text_regions
utils/textblock_mask.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ from typing import Tuple
4
+ from .imgproc_utils import draw_connected_labels
5
+ from .stroke_width_calculator import strokewidth_check
6
+
7
def opencv_inpaint(img, mask):
    """Inpaint ``img`` where ``mask`` is set, using cv2.inpaint with radius 3 and INPAINT_NS."""
    return cv2.inpaint(img, mask, 3, cv2.INPAINT_NS)
8
+
9
def show_img_by_dict(imgdicts):
    """Show every image of ``imgdicts`` in a window named after its key, then wait for a key press."""
    for window_name, image in imgdicts.items():
        cv2.imshow(window_name, image)
    cv2.waitKey(0)
13
+
14
# compute the mean rgb of the text
def letter_calculator(img, mask, bground_rgb, show_process=False):
    """Estimate the mean text colour of ``img`` inside ``mask``.

    The image is Otsu-thresholded; the result is inverted unless the grey value
    of ``bground_rgb`` is below 127, then restricted to ``mask``.

    Returns:
        ``(letter_rgb, threshed)`` where ``letter_rgb`` is the per-channel mean
        (as ints) of the selected pixels, or ``[-1, -1, -1]`` when no pixel is
        selected, and ``threshed`` is the masked binary image.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    # rgb to grey
    bg_gray = 0.299 * bground_rgb[0] + 0.587 * bground_rgb[1] + 0.114 * bground_rgb[2]
    dark_limit = 127
    _, binary = cv2.threshold(gray, 127, 255, cv2.THRESH_OTSU)

    # dark backgrounds keep the thresholded image, everything else is inverted
    if not (bg_gray < dark_limit):
        binary = 255 - binary

    binary = cv2.bitwise_and(binary, mask)
    text_coords = np.where(binary == 255)
    text_pixels = img[text_coords]

    if text_pixels.shape[0] == 0:
        return [-1, -1, -1], binary

    letter_rgb = np.mean(text_pixels, axis=0).astype(int).tolist()

    if show_process:
        cv2.imshow("thresh", binary)
        preview = np.full_like(img, 127)
        preview[text_coords] = letter_rgb
        cv2.imshow("letter_img", preview)

    return letter_rgb, binary
51
+
52
# preprocessing that makes the text colour extraction a bit more accurate
def usm(src):
    """Return an unsharp-masked copy of ``src``.

    Computes ``1.5 * src - 0.5 * blur`` where ``blur`` is ``src`` blurred with
    ``cv2.GaussianBlur(src, (0, 0), 5)``.
    """
    blurred = cv2.GaussianBlur(src, (0, 0), 5)
    return cv2.addWeighted(src, 1.5, blurred, -0.5, 0)
61
+
62
# mean text rgb, method 2; using the median instead of the mean might work better
def textrgb_calculator(img, text_mask, show_process=False):
    """Return the mean colour (uint8) of the sharpened image under ``text_mask``.

    ``text_mask`` is eroded once, ``img`` is sharpened with ``usm`` and the mean
    of the pixels where the eroded mask equals 255 is returned. When no pixel is
    selected the result is all zeros instead of a NaN-derived value.
    """
    # NOTE(review): the kernel is passed as the tuple (3, 3), not a 3x3 array -
    # confirm this is the intended structuring element.
    text_mask = cv2.erode(text_mask, (3, 3), iterations=1)
    usm_img = usm(img)
    text_coords = np.where(text_mask == 255)
    if text_coords[0].size == 0:
        # empty selection: np.mean would emit a warning and yield NaN
        overall_meanrgb = np.zeros(usm_img.shape[2:])
    else:
        overall_meanrgb = np.mean(usm_img[text_coords], axis=0)
    if show_process:
        colored_text_board = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + 127
        colored_text_board[text_coords] = overall_meanrgb
        cv2.imshow("usm", usm_img)
        cv2.imshow("textcolor", colored_text_board)
    return overall_meanrgb.astype(np.uint8)
73
+
74
# compute the mean rgb of the background and the spread of its grey values
def bground_calculator(buble_img, back_ground_mask, dilate=True):
    """Measure the background of ``buble_img`` where ``back_ground_mask`` is 0.

    Args:
        buble_img: image to sample.
        back_ground_mask: mask whose zero pixels mark the background.
        dilate: dilate the mask once with a 3x3 kernel before sampling.

    Returns:
        ``(bground_aver, bground_region, sd)``: the per-channel mean as ints
        (``[-1, -1, -1]`` when no background pixel exists), the ``np.where``
        index tuple of the background, and the mean squared deviation of the
        grey values (i.e. the variance - no square root is taken), or ``-1``
        when no background pixel exists.
    """
    kernel = np.ones((3,3),np.uint8)
    if dilate:
        back_ground_mask = cv2.dilate(back_ground_mask, kernel, iterations = 1)
    bground_region = np.where(back_ground_mask==0)
    sd = -1
    if len(bground_region[0]) != 0:
        pix_array = buble_img[bground_region]
        bground_aver = np.mean(pix_array, axis=0).astype(int)
        gray = cv2.cvtColor(buble_img, cv2.COLOR_RGB2GRAY)
        gray_pixarray = gray[bground_region]
        gray_aver = np.mean(gray_pixarray)
        sd = np.mean(np.power(gray_pixarray - gray_aver, 2))
    else:
        bground_aver = np.array([-1, -1, -1])

    return bground_aver, bground_region, sd
95
+
96
# Input: ROI of a text block. Segments the text mask, computes the text/background
# colour statistics from it and decides between solid-colour fill and inpainting.
def canny_flood(img, show_process=False, inpaint_sdthresh=10, **kwargs):
    """Segment text and balloon masks of a text-block crop using Canny edges and flood fill.

    Args:
        img: image crop of the text block.
        show_process: when true, print debug info and show intermediate masks.
        inpaint_sdthresh: inpainting is requested unless the background spread
            returned by ``bground_calculator`` is below this value.
        **kwargs: ignored.

    Returns:
        ``(mask, ballon_mask, bub_dict)`` where ``bub_dict`` has the keys
        ``"rgb"``, ``"bground_rgb"``, ``"inner_rect"`` and ``"need_inpaint"``.
        ``inner_rect`` is ``[-1, -1, -1, -1]`` when no text rectangle was found,
        otherwise the bounding rect followed by an extra ``-1``.
    """
    # cv2.setNumThreads(4)
    WHITE = (255, 255, 255)
    BLACK = (0, 0, 0)
    kernel = np.ones((3,3),np.uint8)
    orih, oriw = img.shape[0], img.shape[1]
    # large crops are shrunk, small ones enlarged before edge detection
    scaleR = 1
    if orih > 300 and oriw > 300:
        scaleR = 0.6
    elif orih < 120 or oriw < 120:
        scaleR = 1.4

    if scaleR != 1:
        h, w = img.shape[0], img.shape[1]
        orimg = np.copy(img)
        img = cv2.resize(img, (int(w*scaleR), int(h*scaleR)), interpolation=cv2.INTER_AREA)
    h, w = img.shape[0], img.shape[1]
    img_area = h * w

    # NOTE(review): cv2.BORDER_DEFAULT is passed in the sigmaX position - confirm intent.
    cpimg = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
    detected_edges = cv2.Canny(cpimg, 70, 140, L2gradient=True, apertureSize=3)
    # draw a white frame so that contours touching the border are found ...
    cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)

    cons, hiers = cv2.findContours(detected_edges, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)

    # ... and erase the frame again afterwards
    cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), BLACK, 1, cv2.LINE_8)

    ballon_mask, outer_index = np.zeros((h, w), np.uint8), -1

    min_retval = np.inf
    mask = np.zeros((h, w), np.uint8)
    difres = 10
    seedpnt = (int(w/2), int(h/2))
    # Among contours whose bounding rect covers >= 40% of the image, keep the one
    # whose flood fill from the image centre is smallest but still > 30% of the image.
    for ii in range(len(cons)):
        rect = cv2.boundingRect(cons[ii])
        if rect[2]*rect[3] < img_area*0.4:
            continue

        mask = cv2.drawContours(mask, cons, ii, (255), 2)
        cpmask = np.copy(mask)
        cv2.rectangle(mask, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
        retval, _, _, rect = cv2.floodFill(cpmask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))

        if retval <= img_area * 0.3:
            # filled area too small: remove this contour from the accumulated mask
            mask = cv2.drawContours(mask, cons, ii, (0), 2)
        if retval < min_retval and retval > img_area * 0.3:
            min_retval = retval
            ballon_mask = cpmask

    # uint8 arithmetic: pixels that were flood-filled with 127 become 0 here
    ballon_mask = 127 - ballon_mask
    ballon_mask = cv2.dilate(ballon_mask, kernel,iterations = 1)
    outer_area, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(30), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
    ballon_mask = 30 - ballon_mask
    retval, ballon_mask = cv2.threshold(ballon_mask, 1, 255, cv2.THRESH_BINARY)
    ballon_mask = cv2.bitwise_not(ballon_mask, ballon_mask)

    detected_edges = cv2.dilate(detected_edges, kernel, iterations = 1)
    # At most two passes: restrict the edges to the balloon, flood the background
    # from two opposite corners and stop once the remaining (text) area is below
    # 85% of the balloon area; the balloon mask is eroded once per pass.
    for ii in range(2):
        detected_edges = cv2.bitwise_and(detected_edges, ballon_mask)
        mask = np.copy(detected_edges)
        bgarea1, _, _, rect = cv2.floodFill(mask, mask=None, seedPoint=(0, 0), flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
        bgarea2, _, _, rect = cv2.floodFill(mask, mask=None, seedPoint=(detected_edges.shape[1]-1, detected_edges.shape[0]-1), flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
        txt_area = min(img_area - bgarea1, img_area - bgarea2)
        ratio_ob = txt_area / outer_area
        ballon_mask = cv2.erode(ballon_mask, kernel,iterations = 1)
        if ratio_ob < 0.85:
            break

    # flooded background (127) becomes 0, everything else is binarised to 255
    mask = 127 - mask
    retval, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
    if scaleR != 1:
        # go back to the original resolution
        img = orimg
        ballon_mask = cv2.resize(ballon_mask, (oriw, orih))
        mask = cv2.resize(mask, (oriw, orih))

    bg_mask = cv2.bitwise_or(mask, 255-ballon_mask)
    mask = cv2.bitwise_and(mask, ballon_mask)

    bground_aver, bground_region, sd = bground_calculator(img, bg_mask)
    inner_rect = None
    threshed = np.zeros((img.shape[0], img.shape[1]), np.uint8)

    # -1 in the first channel signals "no background pixels found"
    if bground_aver[0] != -1:
        letter_aver, threshed = letter_calculator(img, mask, bground_aver, show_process=show_process)
        if letter_aver[0] != -1:
            mask = cv2.dilate(threshed, kernel, iterations=1)
            inner_rect = cv2.boundingRect(cv2.findNonZero(mask))
    else: letter_aver = [0, 0, 0]

    if sd != -1 and sd < inpaint_sdthresh:
        need_inpaint = False
    else:
        need_inpaint = True
    if show_process:
        print(f"\nneed_inpaint: {need_inpaint}, sd: {sd}, {type(inner_rect)}")
        show_img_by_dict({"outermask": ballon_mask, "detect": detected_edges, "mask": mask})


    if isinstance(inner_rect, tuple):
        inner_rect = [ii for ii in inner_rect]
    if inner_rect is None:
        inner_rect = [-1, -1, -1, -1]
    else:
        inner_rect.append(-1)

    bground_aver = bground_aver.astype(np.uint8)
    bub_dict = {"rgb": letter_aver,
                "bground_rgb": bground_aver,
                "inner_rect": inner_rect,
                "need_inpaint": need_inpaint}
    return mask, ballon_mask, bub_dict
208
+
209
# Input: ROI of a text block. Segments the text mask, computes the text/background
# colour statistics from it and decides between solid-colour fill and inpainting.
def connected_canny_flood(img, show_process=False, inpaint_sdthresh=10, apply_strokewidth_check=0, **kwargs):
    """Segment text and balloon masks of a text-block crop using connected components.

    Args:
        img: image crop of the text block.
        show_process: when truthy, print debug info and show intermediate images.
        inpaint_sdthresh: inpainting is requested unless the background spread
            returned by ``bground_calculator`` is below this value.
        apply_strokewidth_check: when > 0 the text mask is filtered with
            ``strokewidth_check``.
        **kwargs: ignored.

    Returns:
        ``(mask, ballon_mask, bub_dict)`` where ``bub_dict`` has the keys
        ``"rgb"``, ``"bground_rgb"``, ``"inner_rect"`` and ``"need_inpaint"``.
    """

    # find the outer-contour mask that is most likely the balloon
    def find_outermask(img):
        connectivity = 4
        num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity, cv2.CV_16U)
        drawtext = np.zeros((img.shape[0], img.shape[1]), np.uint8)

        # component with the largest value in stats column 4
        max_ind = np.argmax(stats[:, 4])
        # sec_ind / maxbbox_area are computed but not used afterwards
        maxbbox_area, sec_ind = -1, -1
        for ind, stat in enumerate(stats):
            if ind != max_ind:
                bbarea = stat[2] * stat[3]
                if bbarea > maxbbox_area:
                    maxbbox_area = bbarea
                    sec_ind = ind
        # NOTE(review): recovered source lost its indentation; this line is assumed
        # to sit outside the loop above - confirm against the upstream file.
        drawtext[np.where(labels==max_ind)] = 255

        cv2.rectangle(drawtext, (0, 0), (img.shape[1]-1, img.shape[0]-1), (0, 0, 0), 1, cv2.LINE_8)
        cons, hiers = cv2.findContours(drawtext, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
        img_area = img.shape[0] * img.shape[1]

        # keep only contours whose bounding rect covers more than 30% of the image
        rects = np.array([cv2.boundingRect(cnt) for cnt in cons])
        rect_area = np.array([rect[2] * rect[3] for rect in rects])
        quali_ind = np.where(rect_area > img_area * 0.3)[0]
        ballon_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
        for ind in quali_ind:
            ballon_mask = cv2.drawContours(ballon_mask, cons, ind, (255), 2)

        # flood fill from the image centre with value 127, then binarise
        seedpnt = (int(ballon_mask.shape[1]/2), int(ballon_mask.shape[0]/2))
        difres = 10
        retval, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
        ballon_mask = 255 - cv2.threshold(ballon_mask - 127, 1, 255, cv2.THRESH_BINARY)[1]
        return num_labels, labels, stats, centroids, ballon_mask

    # Converting BGR straight to grayscale can make text hard to tell apart from the
    # background, e.g. red text on black in the test samples,
    # but there is always one channel in which text and background separate well.
    # Returns the channel that separates best.
    def ccctest(img, crop_r=0.1):
        # img = usm(img)
        # work on an image of at most 100 rows
        maxh = 100
        if img.shape[0] > maxh:
            scaleR = maxh / img.shape[0]
            im = cv2.resize(img, (int(img.shape[1]*scaleR), int(img.shape[0]*scaleR)), interpolation=cv2.INTER_AREA)
        else:
            im = img

        textlabel_counter = 0
        reverse = False
        c_ind = 0

        num_labels, labels, stats, centroids, pseduo_outermask = find_outermask(cv2.threshold(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY), 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)[1])
        # append the grayscale image as a fourth channel (index 3)
        grayim = np.expand_dims(np.array(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)), axis=2)
        im = np.append(im, grayim, axis=2)
        outer_cords = np.where(pseduo_outermask==255)
        for bgr_ind in range(4):
            channel = im[:, :, bgr_ind]
            ret, thresh = cv2.threshold(channel, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)

            tmp_reverse = False

            # invert when the pixels under the outer mask are mostly white
            if np.mean(thresh[outer_cords]) > 160:
                thresh = 255 - thresh
                tmp_reverse = True

            num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 4, cv2.CV_16U)
            # draw_connected_labels(num_labels, labels, stats, centroids)
            # cv2.waitKey(0)
            max_ind = np.argmax(stats[:, 4])
            maxr, minr = 0.5, 0.001
            maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr
            minarea = im.shape[0] * im.shape[1] * minr

            # count components that are smaller than half the largest component in
            # both dimensions and whose bbox exceeds 0.1% of the image
            tmp_counter = 0
            for stat in stats:
                bboxarea = stat[2] * stat[3]
                if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
                    tmp_counter += 1
            # the channel with the most such components wins
            if tmp_counter > textlabel_counter:
                textlabel_counter = tmp_counter
                c_ind = bgr_ind
                reverse = tmp_reverse
        return c_ind, reverse

    channel_index, reverse = ccctest(img)
    # index 3 stands for the grayscale image
    chanel = img[:, :, channel_index] if channel_index < 3 else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    ret, thresh = cv2.threshold(chanel, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)

    # reverse to get white text on black bg
    if reverse:
        thresh = 255 - thresh
    num_labels, labels, stats, centroids, ballon_mask = find_outermask(thresh)
    img_area = img.shape[0] * img.shape[1]
    text_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
    max_ind = np.argmax(stats[:, 4])
    # every component except the largest one, and below 40% of the image, is text
    for lab in (range(num_labels)):
        stat = stats[lab]
        if lab != max_ind and stat[4] < img_area * 0.4:
            labcord = np.where(labels==lab)
            text_mask[labcord] = 255

    text_mask = cv2.bitwise_and(text_mask, ballon_mask)
    if apply_strokewidth_check > 0:
        text_mask = strokewidth_check(text_mask, labels, num_labels, stats, debug_type=show_process-1)

    text_color = textrgb_calculator(img, text_mask, show_process=show_process)
    inner_rect = cv2.boundingRect(cv2.findNonZero(cv2.dilate(text_mask, (3, 3), iterations=1)))
    inner_rect = [ii for ii in inner_rect]
    inner_rect.append(-1)

    bg_mask = cv2.bitwise_or(text_mask, 255-ballon_mask)

    bground_aver, bground_region, sd = bground_calculator(img, bg_mask)

    # NOTE(review): cv2.BORDER_DEFAULT is passed in the sigmaX position - confirm intent.
    mask = cv2.GaussianBlur(text_mask,(3,3),cv2.BORDER_DEFAULT)
    _, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
    if sd != -1 and sd < inpaint_sdthresh:
        need_inpaint = False
    else:
        need_inpaint = True

    if show_process:
        print(f"\nuse inpaint: {need_inpaint}, sd: {sd}, {type(inner_rect)}")
        draw_connected_labels(num_labels, labels, stats, centroids)
        show_img_by_dict({"thresh": thresh, "ori": img, "outer": ballon_mask, "text": text_mask, "bgmask": bg_mask})

    bground_aver = bground_aver.astype(np.uint8)
    bub_dict = {"rgb": text_color,
                "bground_rgb": bground_aver,
                "inner_rect": inner_rect,
                "need_inpaint": need_inpaint}
    return mask, ballon_mask, bub_dict
342
+
343
+
344
def existing_mask(img, mask: np.ndarray):
    """Pass an already available mask through as both text mask and balloon mask.

    ``img`` is not used. The returned dict reports black text on a white
    background and always requests inpainting.
    """
    bub_dict = {
        "rgb": [0, 0, 0],
        "bground_rgb": [255, 255, 255],
        "need_inpaint": True,
    }
    return mask, mask, bub_dict
347
+
348
+
349
def extract_ballon_mask(img: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    '''
    Given the original (cropped) img and its text mask, return the balloon
    mask and the non-text mask (balloon minus text).

    The balloon is the filled contour of smallest area that covers the whole
    text mask. Both return values are None when no such contour exists.
    '''
    blurred = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
    height, width = blurred.shape[:2]
    text_sum = np.sum(mask)
    edges = cv2.Canny(blurred, 70, 140, L2gradient=True, apertureSize=3)
    e_size = 1
    element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
    edges = cv2.dilate(edges, element, iterations=1)
    tx, ty, tw, th = cv2.boundingRect(cv2.findNonZero(mask))
    text_x1, text_y1, text_x2, text_y2 = tx, ty, tx + tw, ty + th

    # draw the bounding rect in case there is no closed ballon
    cv2.rectangle(edges, (0, 0), (width-1, height-1), (255, 255, 255), 1, cv2.LINE_8)
    not_text = 255 - mask
    edges = cv2.bitwise_and(edges, not_text)

    contours, _ = cv2.findContours(edges, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
    smallest_area = width * height
    ballon_mask = None
    # minimum contour which covers all text mask must be the ballon
    for idx, contour in enumerate(contours):
        cx, cy, cw, ch = cv2.boundingRect(contour)
        encloses_text = (
            cx <= text_x1 and cy <= text_y1
            and cx + cw >= text_x2 and cy + ch >= text_y2
        )
        if not encloses_text:
            continue
        filled = np.zeros_like(edges)
        cv2.drawContours(filled, contours, idx, (255, 255, 255), -1, cv2.LINE_8)
        if cv2.bitwise_and(filled, mask).sum() < text_sum:
            continue
        area = cv2.contourArea(contour)
        if area < smallest_area:
            smallest_area = area
            ballon_mask = filled

    if ballon_mask is None:
        return None, None
    return ballon_mask, cv2.bitwise_and(ballon_mask, not_text)
utils/textlines_merge.py ADDED
@@ -0,0 +1,568 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import functools
3
+ from typing import Tuple, List, ClassVar, Union, Any, Dict, Set
4
+ from collections import Counter
5
+ try:
6
+ functools.cached_property
7
+ except AttributeError: # Supports Python versions below 3.8
8
+ from backports.cached_property import cached_property
9
+ functools.cached_property = cached_property
10
+
11
+ import numpy as np
12
+ from shapely.geometry import Polygon, MultiPoint
13
+ import cv2
14
+ import networkx as nx
15
+
16
+
17
class BBox:
    """Axis-aligned box (x, y, w, h) carrying a text, a probability and fg/bg colours."""

    def __init__(self, x: int, y: int, w: int, h: int, text: str, prob: float, fg_r: int = 0, fg_g: int = 0, fg_b: int = 0, bg_r: int = 0, bg_g: int = 0, bg_b: int = 0):
        self.x, self.y, self.w, self.h = x, y, w, h
        self.text = text
        self.prob = prob
        self.fg_r, self.fg_g, self.fg_b = fg_r, fg_g, fg_b
        self.bg_r, self.bg_g, self.bg_b = bg_r, bg_g, bg_b

    def width(self):
        """Return the box width."""
        return self.w

    def height(self):
        """Return the box height."""
        return self.h

    def to_points(self):
        """Return the corners as four arrays: top-left, top-right, bottom-right, bottom-left."""
        left, top = self.x, self.y
        right, bottom = self.x + self.w, self.y + self.h
        corners = ((left, top), (right, top), (right, bottom), (left, bottom))
        return tuple(np.array([cx, cy]) for cx, cy in corners)

    @property
    def xywh(self):
        """Return ``[x, y, w, h]`` as an int32 array."""
        return np.array([self.x, self.y, self.w, self.h], dtype=np.int32)
45
+
46
+
47
+ class Quadrilateral(object):
48
+ """
49
+ Helper for storing textlines that contains various helper functions.
50
+ """
51
def __init__(self, pts: np.ndarray, text: str, prob: float, fg_r: int = 0, fg_g: int = 0, fg_b: int = 0, bg_r: int = 0, bg_g: int = 0, bg_b: int = 0):
    """Store the points as ordered by ``sort_pnts`` together with text, probability and colours.

    ``direction`` is ``'v'`` when ``sort_pnts`` reports a vertical line, else ``'h'``.
    """
    ordered_pts, is_vertical = sort_pnts(pts)
    self.pts = ordered_pts
    self.direction = 'v' if is_vertical else 'h'
    self.text = text
    self.prob = prob
    self.fg_r, self.fg_g, self.fg_b = fg_r, fg_g, fg_b
    self.bg_r, self.bg_g, self.bg_b = bg_r, bg_g, bg_b
    self.assigned_direction: str = None
    self.textlines: List[Quadrilateral] = []
67
+
68
@functools.cached_property
def structure(self) -> List[np.ndarray]:
    """Return the integer midpoints of the edges pts[0]-pts[1], pts[2]-pts[3], pts[1]-pts[2] and pts[3]-pts[0]."""
    corners = self.pts
    edge_ends = ((0, 1), (2, 3), (1, 2), (3, 0))
    return [((corners[a] + corners[b]) / 2).astype(int) for a, b in edge_ends]
75
+
76
@functools.cached_property
def valid(self) -> bool:
    """Return whether the two axes through opposite edge midpoints are within 10 degrees of perpendicular."""
    l1a, l1b, l2a, l2b = (p.astype(np.float32) for p in self.structure)
    axis_a = l1b - l1a
    axis_b = l2b - l2a
    unit_a = axis_a / np.linalg.norm(axis_a)
    unit_b = axis_b / np.linalg.norm(axis_b)
    angle_deg = np.arccos(np.dot(unit_a, unit_b)) * 180 / np.pi
    return abs(angle_deg - 90) < 10
86
+
87
@property
def fg_colors(self):
    """Return the foreground colour as the array ``[fg_r, fg_g, fg_b]``."""
    channels = [self.fg_r, self.fg_g, self.fg_b]
    return np.array(channels)
90
+
91
@property
def bg_colors(self):
    """Return the background colour as the array ``[bg_r, bg_g, bg_b]``."""
    channels = [self.bg_r, self.bg_g, self.bg_b]
    return np.array(channels)
94
+
95
@functools.cached_property
def aspect_ratio(self) -> float:
    """hor/ver: length of the second structure axis divided by the length of the first."""
    l1a, l1b, l2a, l2b = (p.astype(np.float32) for p in self.structure)
    ver_len = np.linalg.norm(l1b - l1a)
    hor_len = np.linalg.norm(l2b - l2a)
    return hor_len / ver_len
102
+
103
@functools.cached_property
def font_size(self) -> float:
    """Return the length of the shorter of the two structure axes."""
    l1a, l1b, l2a, l2b = (p.astype(np.float32) for p in self.structure)
    first_len = np.linalg.norm(l1b - l1a)
    second_len = np.linalg.norm(l2b - l2a)
    return min(second_len, first_len)
109
+
110
def width(self) -> int:
    """Return the width of the axis-aligned bounding box."""
    bounding_box = self.aabb
    return bounding_box.w
112
+
113
def height(self) -> int:
    """Return the height of the axis-aligned bounding box."""
    bounding_box = self.aabb
    return bounding_box.h
115
+
116
@functools.cached_property
def xyxy(self):
    """Return the axis-aligned bounding box as the tuple ``(x1, y1, x2, y2)``."""
    box = self.aabb
    left, top = box.x, box.y
    return left, top, left + box.w, top + box.h
119
+
120
def clip(self, width, height):
    """Round the points in place and clamp x to [0, width] and y to [0, height].

    NOTE(review): cached properties that were already computed from ``pts``
    are not refreshed here - confirm callers clip before reading them.
    """
    for axis, upper in enumerate((width, height)):
        self.pts[:, axis] = np.clip(np.round(self.pts[:, axis]), 0, upper)
123
+
124
+ # @functools.cached_property
125
+ # def points(self):
126
+ # ans = [a.astype(np.float32) for a in self.structure]
127
+ # return [Point(a[0], a[1]) for a in ans]
128
+
129
@functools.cached_property
def aabb(self) -> BBox:
    """Return the axis-aligned bounding box of ``pts`` as a ``BBox`` carrying this line's text, prob and colours."""
    lower = np.min(self.pts, axis = 0)
    upper = np.max(self.pts, axis = 0)
    box_w = upper[0] - lower[0]
    box_h = upper[1] - lower[1]
    return BBox(lower[0], lower[1], box_w, box_h, self.text, self.prob, self.fg_r, self.fg_g, self.fg_b, self.bg_r, self.bg_g, self.bg_b)
135
+
136
def get_transformed_region(self, img, direction, textheight) -> np.ndarray:
    """Warp this quadrilateral out of ``img`` into an upright rectangle.

    Args:
        img: source image.
        direction: ``'h'`` or ``'v'``; stored in ``assigned_direction``. For
            ``'v'`` the warped region is rotated 90 degrees counter-clockwise.
        textheight: target text height (height for ``'h'``, width before
            rotation for ``'v'``).

    Returns:
        The warped region, or ``None`` for any other ``direction``.
    """
    l1a, l1b, l2a, l2b = (p.astype(np.float32) for p in self.structure)
    v_vec = l1b - l1a
    h_vec = l2b - l2a
    ratio = np.linalg.norm(v_vec) / np.linalg.norm(h_vec)

    src_pts = self.pts.astype(np.int64).copy()
    im_h, im_w = img.shape[:2]

    x1 = np.clip(src_pts[:, 0].min(), 0, im_w)
    y1 = np.clip(src_pts[:, 1].min(), 0, im_h)
    x2 = np.clip(src_pts[:, 0].max(), 0, im_w)
    y2 = np.clip(src_pts[:, 1].max(), 0, im_h)
    # cv2.warpPerspective could overflow if image size is too large, better crop it here
    img_croped = img[y1: y2, x1: x2]

    src_pts[:, 0] -= x1
    src_pts[:, 1] -= y1

    self.assigned_direction = direction
    if direction == 'h':
        h = max(int(textheight), 2)
        w = max(int(round(textheight / ratio)), 2)
    elif direction == 'v':
        w = max(int(textheight), 2)
        h = max(int(round(textheight * ratio)), 2)
    else:
        return None

    dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
    M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    region = cv2.warpPerspective(img_croped, M, (w, h))
    if direction == 'v':
        region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
    return region
173
+
174
@functools.cached_property
def is_axis_aligned(self) -> bool:
    """Return True when the first structure axis is (almost) parallel to the x or y axis.

    The axis runs between the first two points of ``structure``; it counts as
    aligned when its unit vector has a component smaller than 1e-2 along
    either image axis.
    """
    l1a, l1b = (p.astype(np.float32) for p in self.structure[:2])
    v1 = l1b - l1a
    e1 = np.array([0, 1])
    e2 = np.array([1, 0])
    unit_vector_1 = v1 / np.linalg.norm(v1)
    if abs(np.dot(unit_vector_1, e1)) < 1e-2 or abs(np.dot(unit_vector_1, e2)) < 1e-2:
        return True
    return False
186
+
187
@functools.cached_property
def is_approximate_axis_aligned(self) -> bool:
    """Return True when either structure axis has a unit-vector component below 0.05 along the x or y axis."""
    l1a, l1b, l2a, l2b = (p.astype(np.float32) for p in self.structure)
    tolerance = 0.05
    image_axes = (np.array([0, 1]), np.array([1, 0]))
    for vec in (l1b - l1a, l2b - l2a):
        unit = vec / np.linalg.norm(vec)
        for axis in image_axes:
            if abs(np.dot(unit, axis)) < tolerance:
                return True
    return False
199
+
200
@functools.cached_property
def cosangle(self) -> float:
    """Cosine of the angle between the first structure line and the x axis.

    The first two points of ``self.structure`` define the line; its unit
    direction is dotted with ``[1, 0]``. The remaining two points are
    unpacked but unused.
    """
    start, end, _, _ = [point.astype(np.float32) for point in self.structure]
    direction = end - start
    x_axis = np.array([1, 0])
    return np.dot(direction / np.linalg.norm(direction), x_axis)
207
+
208
@functools.cached_property
def angle(self) -> float:
    """Angle in radians derived from ``self.cosangle``, wrapped modulo pi.

    ``arccos`` yields a value in [0, pi]; adding pi and taking ``fmod`` by
    pi maps an input of exactly pi back to 0.
    """
    shifted = np.arccos(self.cosangle) + np.pi
    return np.fmod(shifted, np.pi)
211
+
212
@functools.cached_property
def centroid(self) -> np.ndarray:
    """Unweighted average of ``self.pts`` taken along axis 0."""
    corner_points = self.pts
    return np.average(corner_points, axis=0)
215
+
216
def distance_to_point(self, p: np.ndarray) -> float:
    """Smallest distance from ``p`` to the outline described by ``self.pts``.

    For each of the four corners, the distance to the corner itself and to
    the edge joining it with the next corner (wrapping around) is measured;
    the minimum is returned. The running minimum starts at 1.0e20.
    """
    corners = self.pts
    nearest = 1.0e20
    for idx in range(4):
        to_corner = distance_point_point(p, corners[idx])
        to_edge = distance_point_lineseg(p, corners[idx], corners[(idx + 1) % 4])
        nearest = min(nearest, to_corner)
        nearest = min(nearest, to_edge)
    return nearest
222
+
223
@functools.cached_property
def polygon(self) -> Polygon:
    """Convex hull of the first four entries of ``self.pts`` as a shapely geometry."""
    corners = [tuple(self.pts[idx]) for idx in range(4)]
    return MultiPoint(corners).convex_hull
226
+
227
@functools.cached_property
def area(self) -> float:
    """Area of ``self.polygon``."""
    hull = self.polygon
    return hull.area
230
+
231
def poly_distance(self, other) -> float:
    """Distance between this object's ``polygon`` and ``other.polygon``."""
    own_hull = self.polygon
    other_hull = other.polygon
    return own_hull.distance(other_hull)
233
+
234
def distance(self, other, rho = 0.5) -> float:
    """Distance to ``other``; forwards ``other`` and ``rho`` to ``distance_impl`` unchanged."""
    result = self.distance_impl(other, rho)
    return result
236
+
237
def distance_impl(self, other, rho = 0.5) -> float:
    """Distance between two text lines, measured at the best-matching anchor.

    Convex hulls are built from pairs of corresponding points of ``self``
    and ``other``; each hull area divided by the larger font size acts as an
    alignment score. The anchor whose score is below ``fs * rho`` and
    smallest wins, and the Euclidean distance between the two anchor points
    is returned.

    When ``self.assigned_direction`` is ``'h'`` the candidates are
    ``'h_left'`` (default, ``pts[0]``), ``'h_right'`` (``pts[1]``) and
    ``'h_middle'`` (``structure[0]``). Otherwise they are ``'v_top'``
    (default, ``pts[0]``) and ``'v_bottom'`` (``pts[2]``).
    ``other.assigned_direction`` is not consulted.
    """
    def hull_area(first, second, third, fourth):
        # Area of the convex hull spanned by four points.
        return MultiPoint([tuple(first), tuple(second), tuple(third), tuple(fourth)]).convex_hull.area

    def anchor_gap(own_point, other_point):
        # Euclidean distance between two anchor points.
        return dist(own_point[0], own_point[1], other_point[0], other_point[1])

    horizontal = self.assigned_direction == 'h'
    fs = max(self.font_size, other.font_size)

    if horizontal:
        area_left = hull_area(self.pts[0], self.pts[3], other.pts[0], other.pts[3])
        area_right = hull_area(self.pts[2], self.pts[1], other.pts[2], other.pts[1])
        area_middle = hull_area(
            self.structure[0],
            self.structure[1],
            other.structure[0],
            other.structure[1],
        )
        score_left = area_left / fs
        score_right = area_right / fs
        score_middle = area_middle / fs

        pattern = 'h_left'
        if score_right < fs * rho and score_right < score_left:
            pattern = 'h_right'
        if score_middle < fs * rho and score_middle < score_left and score_middle < score_right:
            pattern = 'h_middle'

        if pattern == 'h_left':
            return anchor_gap(self.pts[0], other.pts[0])
        if pattern == 'h_right':
            return anchor_gap(self.pts[1], other.pts[1])
        return anchor_gap(self.structure[0], other.structure[0])

    area_top = hull_area(self.pts[0], self.pts[1], other.pts[0], other.pts[1])
    area_bottom = hull_area(self.pts[2], self.pts[3], other.pts[2], other.pts[3])
    score_top = area_top / fs
    score_bottom = area_bottom / fs

    pattern = 'v_top'
    if score_bottom < fs * rho and score_bottom < score_top:
        pattern = 'v_bottom'

    if pattern == 'v_top':
        return anchor_gap(self.pts[0], other.pts[0])
    return anchor_gap(self.pts[2], other.pts[2])
288
+
289
def copy(self, new_pts: np.ndarray):
    """Build a new ``Quadrilateral`` from ``new_pts``, reusing this object's text, prob and unpacked fg/bg colors."""
    ctor_args = [new_pts, self.text, self.prob, *self.fg_colors, *self.bg_colors]
    return Quadrilateral(*ctor_args)
291
+
292
+
293
def sort_pnts(pts: np.ndarray):
    '''
    Order the four corners of a quadrilateral and infer its direction.

    No direction is passed in: it is inferred from the structure vector (mean
    of two long-side vectors) of the input points. The original note applies:
    this is reliable enough for text lines but not for blocks.

    Args:
        pts: four (x, y) points as a ``(4, 2)`` ndarray or any nested sequence
            convertible to one (list, tuple, ...).

    Returns:
        ``(sorted_pts, is_vertical)``. In the vertical case the points are
        sorted by y, then the first pair by ascending x and the last pair by
        descending x. In the horizontal case they are sorted by x and each
        pair is split by y into slots ``[0, 3]`` and ``[1, 2]``.

    Raises:
        AssertionError: if the input does not have shape ``(4, 2)``.
    '''

    if not isinstance(pts, np.ndarray):
        # Previously only lists were converted, so tuples failed the assert.
        pts = np.asarray(pts)
    assert pts.shape == (4, 2), f'expected 4 points of shape (4, 2), got {pts.shape}'

    # Take differences in float64: with an unsigned dtype the subtraction
    # would wrap around and corrupt every norm below.
    coords = pts.astype(np.float64)
    pairwise_vec = (coords[:, None] - coords[None]).reshape((16, -1))
    pairwise_vec_norm = np.linalg.norm(pairwise_vec, axis=1)
    # Sorted norms hold 4 zeros (self pairs), then each side/diagonal twice;
    # positions 8 and 10 select two of the long-side vectors.
    long_side_ids = np.argsort(pairwise_vec_norm)[[8, 10]]
    long_side_vecs = pairwise_vec[long_side_ids]
    inner_prod = (long_side_vecs[0] * long_side_vecs[1]).sum()
    if inner_prod < 0:
        # Make both vectors point the same way before averaging.
        long_side_vecs[0] = -long_side_vecs[0]
    struc_vec = np.abs(long_side_vecs.mean(axis=0))
    is_vertical = struc_vec[0] <= struc_vec[1]

    if is_vertical:
        pts = pts[np.argsort(pts[:, 1])]
        pts = pts[[*np.argsort(pts[:2, 0]), *np.argsort(pts[2:, 0])[::-1] + 2]]
        return pts, is_vertical

    pts = pts[np.argsort(pts[:, 0])]
    pts_sorted = np.zeros_like(pts)
    pts_sorted[[0, 3]] = sorted(pts[[0, 1]], key=lambda x: x[1])
    pts_sorted[[1, 2]] = sorted(pts[[2, 3]], key=lambda x: x[1])
    return pts_sorted, is_vertical
323
+
324
+
325
def dist(x1, y1, x2, y2):
    """Euclidean distance between the points (x1, y1) and (x2, y2)."""
    delta_x = x1 - x2
    delta_y = y1 - y2
    return np.sqrt(delta_x ** 2 + delta_y ** 2)
327
+
328
+
329
def distance_point_point(a: np.ndarray, b: np.ndarray) -> float:
    """Norm of the difference ``a - b`` (Euclidean distance for 1-D points)."""
    offset = a - b
    return np.linalg.norm(offset)
331
+
332
+
333
+ # from https://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment
334
def distance_point_lineseg(p: np.ndarray, p1: np.ndarray, p2: np.ndarray):
    """Shortest distance from point ``p`` to the segment ``p1``-``p2``.

    ``p`` is projected onto the segment's line. A projection parameter below
    0 snaps to ``p1``, above 1 snaps to ``p2``, otherwise the foot of the
    perpendicular is used. A zero-length segment is treated as the single
    point ``p1``.
    """
    px, py = p[0], p[1]
    ax, ay = p1[0], p1[1]
    bx, by = p2[0], p2[1]

    seg_dx = bx - ax
    seg_dy = by - ay
    seg_len_sq = seg_dx * seg_dx + seg_dy * seg_dy

    if seg_len_sq != 0:
        t = ((px - ax) * seg_dx + (py - ay) * seg_dy) / seg_len_sq
    else:
        # Degenerate segment: force the "before start" branch below.
        t = -1

    if t < 0:
        near_x, near_y = ax, ay
    elif t > 1:
        near_x, near_y = bx, by
    else:
        near_x = ax + t * seg_dx
        near_y = ay + t * seg_dy

    off_x = px - near_x
    off_y = py - near_y
    return np.sqrt(off_x * off_x + off_y * off_y)
365
+
366
+
367
def quadrilateral_can_merge_region(a: Quadrilateral, b: Quadrilateral, ratio = 1.9, discard_connection_gap = 2, char_gap_tolerance = 0.6, char_gap_tolerance2 = 1.5, font_size_ratio_tol = 1.5, aspect_ratio_tol = 2) -> bool:
    """Decide whether two text lines may belong to the same text region.

    Rejections applied first, in order:
      * polygon distance above ``discard_connection_gap`` times the smaller
        font size;
      * ratio of larger to smaller font size above ``font_size_ratio_tol``;
      * one box with aspect ratio above ``aspect_ratio_tol`` while the other
        is below its reciprocal.

    If both boxes report ``is_approximate_axis_aligned``, the decision is
    taken from their axis-aligned bounding boxes (gap, centre offset, and
    edge alignment scaled by ``char_gap_tolerance2``). Otherwise the boxes
    merge when their angles differ by less than 15 degrees, their polygon
    distance is within ``char_gap_tolerance2`` font sizes, and their font
    sizes differ by at most 25%.
    """
    box_a = a.aabb
    box_b = b.aabb
    char_size = min(a.font_size, b.font_size)
    x1, y1, w1, h1 = box_a.x, box_a.y, box_a.w, box_a.h
    x2, y2, w2, h2 = box_b.x, box_b.y, box_b.w, box_b.h
    gap = Polygon(a.pts).distance(Polygon(b.pts))

    if gap > discard_connection_gap * char_size:
        return False
    if max(a.font_size, b.font_size) / char_size > font_size_ratio_tol:
        return False
    if a.aspect_ratio > aspect_ratio_tol and b.aspect_ratio < 1. / aspect_ratio_tol:
        return False
    if b.aspect_ratio > aspect_ratio_tol and a.aspect_ratio < 1. / aspect_ratio_tol:
        return False

    a_aligned = a.is_approximate_axis_aligned
    b_aligned = b.is_approximate_axis_aligned
    if a_aligned and b_aligned:
        if not gap < char_size * char_gap_tolerance:
            return False
        # Horizontal centres (almost) coincide. NOTE(review): this threshold
        # is not scaled by char_size, unlike the edge checks below.
        if abs(x1 + w1 // 2 - (x2 + w2 // 2)) < char_gap_tolerance2:
            return True
        a_wide = w1 > h1 * ratio
        a_tall = h1 > w1 * ratio
        b_wide = w2 > h2 * ratio
        b_tall = h2 > w2 * ratio
        # A clearly wide box never merges with a clearly tall one.
        if (a_wide and b_tall) or (b_wide and a_tall):
            return False
        edge_tolerance = char_size * char_gap_tolerance2
        if a_wide or b_wide:  # h
            return abs(x1 - x2) < edge_tolerance or abs(x1 + w1 - (x2 + w2)) < edge_tolerance
        if a_tall or b_tall:  # v
            return abs(y1 - y2) < edge_tolerance or abs(y1 + h1 - (y2 + h2)) < edge_tolerance
        return False

    # At least one box is rotated: compare angle, distance and font size.
    if not abs(a.angle - b.angle) < 15 * np.pi / 180:
        return False
    fs_a = a.font_size
    fs_b = b.font_size
    fs = min(fs_a, fs_b)
    if a.poly_distance(b) > fs * char_gap_tolerance2:
        return False
    if abs(fs_a - fs_b) / fs > 0.25:
        return False
    return True
413
+
414
+
415
def quadrilateral_can_merge_region_coarse(a: Quadrilateral, b: Quadrilateral, discard_connection_gap = 2, font_size_ratio_tol = 0.7) -> bool:
    """Coarse test for whether two text lines may share a region.

    Returns False when the assigned directions differ, the angles differ by
    more than 15 degrees, the font-size difference relative to the smaller
    font exceeds ``font_size_ratio_tol``, or the polygon distance exceeds
    ``discard_connection_gap`` times the larger font size. Returns True
    otherwise.
    """
    if a.assigned_direction != b.assigned_direction:
        return False

    max_angle_diff = 15 * np.pi / 180
    if abs(a.angle - b.angle) > max_angle_diff:
        return False

    fs_a, fs_b = a.font_size, b.font_size
    smaller_fs = min(fs_a, fs_b)
    if abs(fs_a - fs_b) / smaller_fs > font_size_ratio_tol:
        return False

    larger_fs = max(fs_a, fs_b)
    too_far = a.poly_distance(b) > discard_connection_gap * larger_fs
    return not too_far
430
+
431
+
432
def split_text_region(
    bboxes: List[Quadrilateral],
    connected_region_indices: Set[int],
    width,
    height,
    gamma = 0.5,
    sigma = 2
) -> List[Set[int]]:
    """Recursively split one connected group of text lines into regions.

    Args:
        bboxes: all text lines; entries are looked up by index.
        connected_region_indices: indices into ``bboxes`` forming one
            candidate region.
        width, height: not used by the splitting logic itself; forwarded
            unchanged to recursive calls.
        gamma: slack on the font size; a distance up to
            ``(1 + gamma) * font_size`` counts as "close".
        sigma: number of standard deviations above the mean edge length that
            the longest spanning-tree edge may reach before it is cut.

    Returns:
        A list of index sets, one per resulting region. An empty input
        yields an empty list.
    """

    connected_region_indices = list(connected_region_indices)

    # case 0: nothing to split (previously crashed on edges[0] below)
    if not connected_region_indices:
        return []

    # case 1: a single line is its own region
    if len(connected_region_indices) == 1:
        return [set(connected_region_indices)]

    # case 2: two lines stay together only if close and similarly oriented
    if len(connected_region_indices) == 2:
        first = bboxes[connected_region_indices[0]]
        second = bboxes[connected_region_indices[1]]
        fs = max(first.font_size, second.font_size)

        if first.distance(second) < (1 + gamma) * fs \
                and abs(first.angle - second.angle) < 0.2 * np.pi:
            return [set(connected_region_indices)]
        else:
            return [set([connected_region_indices[0]]), set([connected_region_indices[1]])]

    # case 3: build the complete distance graph and inspect its spanning tree
    G = nx.Graph()
    for idx in connected_region_indices:
        G.add_node(idx)
    for (u, v) in itertools.combinations(connected_region_indices, 2):
        G.add_edge(u, v, weight=bboxes[u].distance(bboxes[v]))
    # Get distances from neighbouring bboxes
    edges = nx.algorithms.tree.minimum_spanning_edges(G, algorithm='kruskal', data=True)
    # Longest spanning-tree edge first: it is the candidate cut.
    edges = sorted(edges, key=lambda a: a[2]['weight'], reverse=True)
    distances_sorted = [a[2]['weight'] for a in edges]
    fontsize = np.mean([bboxes[idx].font_size for idx in connected_region_indices])
    distances_std = np.std(distances_sorted)
    distances_mean = np.mean(distances_sorted)
    std_threshold = max(0.3 * fontsize + 5, 5)

    b1, b2 = bboxes[edges[0][0]], bboxes[edges[0][1]]
    max_poly_distance = Polygon(b1.pts).distance(Polygon(b2.pts))
    max_centroid_alignment = min(abs(b1.centroid[0] - b2.centroid[0]), abs(b1.centroid[1] - b2.centroid[1]))

    # Keep the group whole when the longest edge is not an outlier (or is
    # short in font-size terms) and the edge lengths are uniform, or the two
    # farthest boxes touch and are aligned.
    if (distances_sorted[0] <= distances_mean + distances_std * sigma \
            or distances_sorted[0] <= fontsize * (1 + gamma)) \
            and (distances_std < std_threshold \
            or max_poly_distance == 0 and max_centroid_alignment < 5):
        return [set(connected_region_indices)]
    else:
        G = nx.Graph()
        for idx in connected_region_indices:
            G.add_node(idx)
        # Split out the most deviating bbox
        for edge in edges[1:]:
            G.add_edge(edge[0], edge[1])
        ans = []
        for node_set in nx.algorithms.components.connected_components(G):
            # Pass gamma/sigma on: the recursion used to fall back silently
            # to the defaults, ignoring the caller's thresholds.
            ans.extend(split_text_region(bboxes, node_set, width, height, gamma, sigma))
        return ans
506
+
507
+
508
+
509
def merge_bboxes_text_region(bboxes: List[Quadrilateral], width, height):
    """Group text lines into regions and yield them one by one.

    This is a generator. For every region it yields
    ``(txtlns, (fg_r, fg_g, fg_b), (bg_r, bg_g, bg_b))`` where ``txtlns`` is
    an array of the region's text lines in reading order and the colors are
    the rounded per-channel means over those lines.

    ``width`` and ``height`` are only forwarded to ``split_text_region``.
    """

    # step 1: divide into multiple text region candidates
    graph = nx.Graph()
    for idx, quad in enumerate(bboxes):
        graph.add_node(idx, box=quad)

    for (u, ubox), (v, vbox) in itertools.combinations(enumerate(bboxes), 2):
        mergeable = quadrilateral_can_merge_region(
            ubox, vbox,
            aspect_ratio_tol=1.3, font_size_ratio_tol=2,
            char_gap_tolerance=1, char_gap_tolerance2=3,
        )
        if mergeable:
            graph.add_edge(u, v)

    # step 2: postprocess - further split each region
    region_indices: List[Set[int]] = []
    for component in nx.algorithms.components.connected_components(graph):
        region_indices.extend(split_text_region(bboxes, component, width, height))

    # step 3: return regions
    color_channels = ('fg_r', 'fg_g', 'fg_b', 'bg_r', 'bg_g', 'bg_b')
    for node_set in region_indices:
        nodes = list(node_set)
        txtlns: List[Quadrilateral] = np.array(bboxes)[nodes]

        # calculate average fg and bg color
        fg_r, fg_g, fg_b, bg_r, bg_g, bg_b = [
            round(np.mean([getattr(box, channel) for box in txtlns]))
            for channel in color_channels
        ]

        # majority vote for direction
        ranked = Counter(box.direction for box in txtlns).most_common(2)
        tied = len(ranked) == 2 and ranked[0][1] == ranked[1][1]
        if tied:
            # Tie-break: take the direction of the most elongated box,
            # judged by aspect ratio or its reciprocal.
            max_aspect_ratio = -100
            for box in txtlns:
                if box.aspect_ratio > max_aspect_ratio:
                    max_aspect_ratio = box.aspect_ratio
                    majority_dir = box.direction
                if 1.0 / box.aspect_ratio > max_aspect_ratio:
                    max_aspect_ratio = 1.0 / box.aspect_ratio
                    majority_dir = box.direction
        else:
            majority_dir = ranked[0][0]

        # sort textlines: 'h' by ascending centroid y, 'v' by descending centroid x
        if majority_dir == 'h':
            nodes.sort(key=lambda i: bboxes[i].centroid[1])
        elif majority_dir == 'v':
            nodes.sort(key=lambda i: -bboxes[i].centroid[0])
        txtlns = np.array(bboxes)[nodes]

        # yield sorted text lines with their averaged colors
        yield txtlns, (fg_r, fg_g, fg_b), (bg_r, bg_g, bg_b)
567
+
568
+
utils/watermark_utils.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ from PIL import Image
3
+
4
def apply_watermark_to_pil_image(img_pil: Image.Image, watermark_path: str, opacity: float = 0.7) -> Image.Image:
    """
    Apply watermark to a PIL image

    The watermark is resized to a fixed 418x120, placed 10px from the left
    and 10px above the bottom edge, then repeated upwards every 8000px.

    Args:
        img_pil (Image.Image): Source PIL image
        watermark_path (str): Path to watermark image
        opacity (float): Watermark opacity (0.0 - 1.0); values of 1.0 or
            more leave the watermark's own alpha untouched

    Returns:
        Image.Image: Watermarked image in RGBA mode. If the watermark file
        is missing or cannot be read, ``img_pil`` is returned unchanged
        (in its original mode).
    """
    if not osp.exists(watermark_path):
        return img_pil

    try:
        # Image.open is lazy and keeps the file open. Converting inside the
        # with-block forces the decode and yields a copy detached from the
        # file, so the handle is closed and decode errors are caught here.
        with Image.open(watermark_path) as watermark_file:
            watermark = watermark_file.convert('RGBA')
    except Exception:
        # Best effort: an unreadable watermark must not break the caller.
        return img_pil

    # Ensure the source image is in RGBA mode
    if img_pil.mode != 'RGBA':
        img_pil = img_pil.convert('RGBA')

    # Fixed watermark size (adjust as needed)
    WATERMARK_FIXED_WIDTH = 418
    WATERMARK_FIXED_HEIGHT = 120

    # Resize watermark
    watermark = watermark.resize((WATERMARK_FIXED_WIDTH, WATERMARK_FIXED_HEIGHT), Image.LANCZOS)

    # Apply opacity by scaling the alpha channel
    if opacity < 1.0:
        alpha = watermark.split()[3]
        alpha = alpha.point(lambda p: p * opacity)
        watermark.putalpha(alpha)

    # Get image dimensions
    img_width, img_height = img_pil.size

    # Create transparent layer for watermarks
    wm_layer = Image.new('RGBA', img_pil.size, (0, 0, 0, 0))

    # Calculate watermark positions (bottom to top)
    initial_y = img_height - watermark.height - 10  # 10px from bottom
    x_position = 10  # 10px from left

    # Repeat watermark vertically
    current_y = initial_y
    while current_y > -watermark.height:
        if current_y < 0:
            # Crop the part that sticks out above the top boundary. The loop
            # condition guarantees at least one visible row remains.
            partial_wm = watermark.crop((0, -current_y, watermark.width, watermark.height))
            wm_layer.paste(partial_wm, (x_position, 0), partial_wm)
        else:
            wm_layer.paste(watermark, (x_position, current_y), watermark)

        current_y -= 8000  # Vertical spacing (adjust as needed)

    # Composite original image with watermark layer
    return Image.alpha_composite(img_pil, wm_layer)
utils/zluda_config.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
# Check whether the device name carries the ZLUDA marker
def zluda_available(device_name):
    """Return True when ``device_name`` contains the substring ``[ZLUDA]``."""
    marker = "[ZLUDA]"
    return marker in device_name
7
+
8
+
9
# Disable cuDNN support under ZLUDA to prevent errors
def enable_zluda_config():
    """Print CUDA device information and, on a ZLUDA device, adjust torch backends.

    Does nothing when CUDA is unavailable. On a device whose name contains
    the ZLUDA marker, cuDNN is disabled and, for each scaled-dot-product
    switch that ``torch.backends.cuda`` exposes, flash / mem-efficient /
    cudnn SDP are turned off while math SDP is turned on.
    """
    if not (hasattr(torch, 'cuda') and torch.cuda.is_available()):
        return

    device_name = torch.cuda.get_device_name(0)
    print('Device name: ', device_name)
    print('Cuda is available: ', torch.cuda.is_available())
    print('Cuda version: ', torch.version.cuda)
    print('ZLUDA is available: ', zluda_available(device_name))

    if not zluda_available(device_name):
        return

    torch.backends.cudnn.enabled = False
    cuda_backend = torch.backends.cuda
    # (setter name, value to apply, message printed after applying it)
    sdp_switches = (
        ('enable_flash_sdp', False, 'Cuda enable flash sdp: '),
        ('enable_math_sdp', True, 'Cuda enable math sdp: '),
        ('enable_mem_efficient_sdp', False, 'Cuda enable mem efficient sdp: '),
        ('enable_cudnn_sdp', False, 'Cuda enable cudnn sdp: '),
    )
    for setter_name, flag, label in sdp_switches:
        # Older torch builds lack some of these setters; skip what is missing.
        if hasattr(cuda_backend, setter_name):
            getattr(cuda_backend, setter_name)(flag)
            print(label, flag)