Upload 48 files
Browse files- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/config.cpython-310.pyc +0 -0
- utils/__pycache__/download_util.cpython-310.pyc +0 -0
- utils/__pycache__/exceptions.cpython-310.pyc +0 -0
- utils/__pycache__/fontformat.cpython-310.pyc +0 -0
- utils/__pycache__/imgproc_utils.cpython-310.pyc +0 -0
- utils/__pycache__/io_utils.cpython-310.pyc +0 -0
- utils/__pycache__/logger.cpython-310.pyc +0 -0
- utils/__pycache__/message.cpython-310.pyc +0 -0
- utils/__pycache__/proj_imgtrans.cpython-310.pyc +0 -0
- utils/__pycache__/registry.cpython-310.pyc +0 -0
- utils/__pycache__/shared.cpython-310.pyc +0 -0
- utils/__pycache__/shared.cpython-39.pyc +0 -0
- utils/__pycache__/split_text_region.cpython-310.pyc +0 -0
- utils/__pycache__/stroke_width_calculator.cpython-310.pyc +0 -0
- utils/__pycache__/structures.cpython-310.pyc +0 -0
- utils/__pycache__/text_layout.cpython-310.pyc +0 -0
- utils/__pycache__/text_processing.cpython-310.pyc +0 -0
- utils/__pycache__/textblock.cpython-310.pyc +0 -0
- utils/__pycache__/textblock_mask.cpython-310.pyc +0 -0
- utils/__pycache__/textlines_merge.cpython-310.pyc +0 -0
- utils/__pycache__/watermark_utils.cpython-310.pyc +0 -0
- utils/__pycache__/zluda_config.cpython-310.pyc +0 -0
- utils/appinfo.py +2 -0
- utils/config.py +287 -0
- utils/download_util.py +371 -0
- utils/exceptions.py +20 -0
- utils/fontformat.py +136 -0
- utils/imgproc_utils.py +413 -0
- utils/io_utils.py +243 -0
- utils/logger.py +99 -0
- utils/message.py +67 -0
- utils/package.py +289 -0
- utils/proj_imgtrans.py +720 -0
- utils/registry.py +272 -0
- utils/shared.py +160 -0
- utils/split_text_region.py +386 -0
- utils/stroke_width_calculator.py +113 -0
- utils/structures.py +84 -0
- utils/text_layout.py +477 -0
- utils/text_processing.py +237 -0
- utils/textblock.py +908 -0
- utils/textblock_mask.py +394 -0
- utils/textlines_merge.py +568 -0
- utils/watermark_utils.py +68 -0
- utils/zluda_config.py +32 -0
utils/__init__.py
ADDED
|
File without changes
|
utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
utils/__pycache__/__init__.cpython-39.pyc
ADDED
|
Binary file (133 Bytes). View file
|
|
|
utils/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (11.2 kB). View file
|
|
|
utils/__pycache__/download_util.cpython-310.pyc
ADDED
|
Binary file (9.14 kB). View file
|
|
|
utils/__pycache__/exceptions.cpython-310.pyc
ADDED
|
Binary file (1.12 kB). View file
|
|
|
utils/__pycache__/fontformat.cpython-310.pyc
ADDED
|
Binary file (5.73 kB). View file
|
|
|
utils/__pycache__/imgproc_utils.cpython-310.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
utils/__pycache__/io_utils.cpython-310.pyc
ADDED
|
Binary file (7.77 kB). View file
|
|
|
utils/__pycache__/logger.cpython-310.pyc
ADDED
|
Binary file (2.96 kB). View file
|
|
|
utils/__pycache__/message.cpython-310.pyc
ADDED
|
Binary file (2.23 kB). View file
|
|
|
utils/__pycache__/proj_imgtrans.cpython-310.pyc
ADDED
|
Binary file (22.1 kB). View file
|
|
|
utils/__pycache__/registry.cpython-310.pyc
ADDED
|
Binary file (8.16 kB). View file
|
|
|
utils/__pycache__/shared.cpython-310.pyc
ADDED
|
Binary file (4.14 kB). View file
|
|
|
utils/__pycache__/shared.cpython-39.pyc
ADDED
|
Binary file (4.08 kB). View file
|
|
|
utils/__pycache__/split_text_region.cpython-310.pyc
ADDED
|
Binary file (9.84 kB). View file
|
|
|
utils/__pycache__/stroke_width_calculator.cpython-310.pyc
ADDED
|
Binary file (3.49 kB). View file
|
|
|
utils/__pycache__/structures.cpython-310.pyc
ADDED
|
Binary file (2.84 kB). View file
|
|
|
utils/__pycache__/text_layout.cpython-310.pyc
ADDED
|
Binary file (9.26 kB). View file
|
|
|
utils/__pycache__/text_processing.cpython-310.pyc
ADDED
|
Binary file (5.42 kB). View file
|
|
|
utils/__pycache__/textblock.cpython-310.pyc
ADDED
|
Binary file (26.8 kB). View file
|
|
|
utils/__pycache__/textblock_mask.cpython-310.pyc
ADDED
|
Binary file (11.1 kB). View file
|
|
|
utils/__pycache__/textlines_merge.cpython-310.pyc
ADDED
|
Binary file (19 kB). View file
|
|
|
utils/__pycache__/watermark_utils.cpython-310.pyc
ADDED
|
Binary file (1.65 kB). View file
|
|
|
utils/__pycache__/zluda_config.cpython-310.pyc
ADDED
|
Binary file (1.13 kB). View file
|
|
|
utils/appinfo.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
branch = 'dev'
|
| 2 |
+
version = '1.4.0'
|
utils/config.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, os, traceback
|
| 2 |
+
import os.path as osp
|
| 3 |
+
import copy
|
| 4 |
+
|
| 5 |
+
from . import shared
|
| 6 |
+
from .fontformat import FontFormat
|
| 7 |
+
from .structures import List, Dict, Config, field, nested_dataclass
|
| 8 |
+
from .logger import logger as LOGGER
|
| 9 |
+
from .io_utils import json_dump_nested_obj, np, serialize_np
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@nested_dataclass
|
| 13 |
+
class ModuleConfig(Config):
|
| 14 |
+
textdetector: str = 'ctd'
|
| 15 |
+
ocr: str = "mit48px"
|
| 16 |
+
inpainter: str = 'lama_large_512px'
|
| 17 |
+
translator: str = "google"
|
| 18 |
+
enable_detect: bool = True
|
| 19 |
+
keep_exist_textlines: bool = False
|
| 20 |
+
enable_ocr: bool = True
|
| 21 |
+
enable_translate: bool = True
|
| 22 |
+
enable_inpaint: bool = True
|
| 23 |
+
textdetector_params: Dict = field(default_factory=lambda: dict())
|
| 24 |
+
ocr_params: Dict = field(default_factory=lambda: dict())
|
| 25 |
+
translator_params: Dict = field(default_factory=lambda: dict())
|
| 26 |
+
inpainter_params: Dict = field(default_factory=lambda: dict())
|
| 27 |
+
translate_source: str = '日本語'
|
| 28 |
+
translate_target: str = '简体中文'
|
| 29 |
+
check_need_inpaint: bool = True
|
| 30 |
+
load_model_on_demand: bool = False
|
| 31 |
+
empty_runcache: bool = False
|
| 32 |
+
|
| 33 |
+
def get_params(self, module_key: str, for_saving=False) -> dict:
|
| 34 |
+
d = self[module_key + '_params']
|
| 35 |
+
if not for_saving:
|
| 36 |
+
return d
|
| 37 |
+
sd = {}
|
| 38 |
+
for module_key, module_params in d.items():
|
| 39 |
+
if module_params is None:
|
| 40 |
+
continue
|
| 41 |
+
saving_module_params = {}
|
| 42 |
+
sd[module_key] = saving_module_params
|
| 43 |
+
for pk, pv in module_params.items():
|
| 44 |
+
if pk in {'description'}:
|
| 45 |
+
continue
|
| 46 |
+
if isinstance(pv, dict):
|
| 47 |
+
pv = pv['value']
|
| 48 |
+
saving_module_params[pk] = pv
|
| 49 |
+
return sd
|
| 50 |
+
|
| 51 |
+
def get_saving_params(self, to_dict=True):
|
| 52 |
+
params = copy.copy(self)
|
| 53 |
+
params.ocr_params = self.get_params('ocr', for_saving=True)
|
| 54 |
+
params.inpainter_params = self.get_params('inpainter', for_saving=True)
|
| 55 |
+
params.textdetector_params = self.get_params('textdetector', for_saving=True)
|
| 56 |
+
params.translator_params = self.get_params('translator', for_saving=True)
|
| 57 |
+
if to_dict:
|
| 58 |
+
return params.__dict__
|
| 59 |
+
return params
|
| 60 |
+
|
| 61 |
+
def stage_enabled(self, idx: int):
|
| 62 |
+
if idx == 0:
|
| 63 |
+
return self.enable_detect
|
| 64 |
+
elif idx == 1:
|
| 65 |
+
return self.enable_ocr
|
| 66 |
+
elif idx == 2:
|
| 67 |
+
return self.enable_translate
|
| 68 |
+
elif idx == 3:
|
| 69 |
+
return self.enable_inpaint
|
| 70 |
+
else:
|
| 71 |
+
raise Exception(f'not supported stage idx: {idx}')
|
| 72 |
+
|
| 73 |
+
def all_stages_disabled(self):
|
| 74 |
+
return (self.enable_detect or self.enable_ocr or self.enable_translate or self.enable_inpaint) is False
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@nested_dataclass
|
| 78 |
+
class DrawPanelConfig(Config):
|
| 79 |
+
pentool_color: List = field(default_factory=lambda: [0, 0, 0])
|
| 80 |
+
pentool_width: float = 30.
|
| 81 |
+
pentool_shape: int = 0
|
| 82 |
+
inpainter_width: float = 30.
|
| 83 |
+
inpainter_shape: int = 0
|
| 84 |
+
current_tool: int = 0
|
| 85 |
+
rectool_auto: bool = False
|
| 86 |
+
rectool_method: int = 0
|
| 87 |
+
recttool_dilate_ksize: int = 0
|
| 88 |
+
|
| 89 |
+
@nested_dataclass
|
| 90 |
+
class ProgramConfig(Config):
|
| 91 |
+
|
| 92 |
+
module: ModuleConfig = field(default_factory=lambda: ModuleConfig())
|
| 93 |
+
drawpanel: DrawPanelConfig = field(default_factory=lambda: DrawPanelConfig())
|
| 94 |
+
global_fontformat: FontFormat = field(default_factory=lambda: FontFormat())
|
| 95 |
+
recent_proj_list: List = field(default_factory=lambda: list())
|
| 96 |
+
show_page_list: bool = False
|
| 97 |
+
imgtrans_paintmode: bool = False
|
| 98 |
+
imgtrans_textedit: bool = True
|
| 99 |
+
imgtrans_textblock: bool = True
|
| 100 |
+
auto_watermark: bool = False
|
| 101 |
+
mask_transparency: float = 0.
|
| 102 |
+
original_transparency: float = 0.
|
| 103 |
+
open_recent_on_startup: bool = True
|
| 104 |
+
let_fntsize_flag: int = 0
|
| 105 |
+
let_fntstroke_flag: int = 0
|
| 106 |
+
let_fntcolor_flag: int = 0
|
| 107 |
+
let_fnt_scolor_flag: int = 0
|
| 108 |
+
let_fnteffect_flag: int = 1
|
| 109 |
+
let_alignment_flag: int = 0
|
| 110 |
+
let_writing_mode_flag: int = 0
|
| 111 |
+
let_family_flag: int = 0
|
| 112 |
+
let_autolayout_flag: bool = True
|
| 113 |
+
let_uppercase_flag: bool = True
|
| 114 |
+
let_show_only_custom_fonts_flag: bool = False
|
| 115 |
+
let_textstyle_indep_flag: bool = False
|
| 116 |
+
text_styles_path: str = osp.join(shared.DEFAULT_TEXTSTYLE_DIR, 'default.json')
|
| 117 |
+
fsearch_case: bool = False
|
| 118 |
+
fsearch_whole_word: bool = False
|
| 119 |
+
fsearch_regex: bool = False
|
| 120 |
+
fsearch_range: int = 0
|
| 121 |
+
gsearch_case: bool = False
|
| 122 |
+
gsearch_whole_word: bool = False
|
| 123 |
+
gsearch_regex: bool = False
|
| 124 |
+
gsearch_range: int = 0
|
| 125 |
+
darkmode: bool = False
|
| 126 |
+
textselect_mini_menu: bool = True
|
| 127 |
+
fold_textarea: bool = False
|
| 128 |
+
show_source_text: bool = True
|
| 129 |
+
show_trans_text: bool = True
|
| 130 |
+
saladict_shortcut: str = "Alt+S"
|
| 131 |
+
search_url: str = "https://www.google.com/search?q="
|
| 132 |
+
ocr_sublist: List = field(default_factory=lambda: list())
|
| 133 |
+
restore_ocr_empty: bool = False
|
| 134 |
+
pre_mt_sublist: List = field(default_factory=lambda: list())
|
| 135 |
+
mt_sublist: List = field(default_factory=lambda: list())
|
| 136 |
+
display_lang: str = field(default_factory=lambda: shared.DEFAULT_DISPLAY_LANG) # to always apply shared.DEFAULT_DISPLAY_LANG
|
| 137 |
+
imgsave_quality: int = 100
|
| 138 |
+
imgsave_ext: str = '.png'
|
| 139 |
+
intermediate_imgsave_ext: str = '.png'
|
| 140 |
+
show_text_style_preset: bool = True
|
| 141 |
+
expand_tstyle_panel: bool = True
|
| 142 |
+
show_text_effect_panel: bool = True
|
| 143 |
+
expand_teffect_panel: bool = True
|
| 144 |
+
text_advanced_format_panel: bool = True
|
| 145 |
+
expand_tadvanced_panel: bool = True
|
| 146 |
+
# Watermark settings
|
| 147 |
+
watermark_enabled: bool = False
|
| 148 |
+
watermark_path: str = ''""
|
| 149 |
+
watermark_opacity: int = 0.7
|
| 150 |
+
@staticmethod
|
| 151 |
+
def load(cfg_path: str):
|
| 152 |
+
|
| 153 |
+
with open(cfg_path, 'r', encoding='utf8') as f:
|
| 154 |
+
config_dict = json.loads(f.read())
|
| 155 |
+
|
| 156 |
+
# for backward compatibility
|
| 157 |
+
if 'dl' in config_dict:
|
| 158 |
+
dl = config_dict.pop('dl')
|
| 159 |
+
if not 'module' in config_dict:
|
| 160 |
+
if 'textdetector_setup_params' in dl:
|
| 161 |
+
textdetector_params = dl.pop('textdetector_setup_params')
|
| 162 |
+
dl['textdetector_params'] = textdetector_params
|
| 163 |
+
if 'inpainter_setup_params' in dl:
|
| 164 |
+
inpainter_params = dl.pop('inpainter_setup_params')
|
| 165 |
+
dl['inpainter_params'] = inpainter_params
|
| 166 |
+
if 'ocr_setup_params' in dl:
|
| 167 |
+
ocr_params = dl.pop('ocr_setup_params')
|
| 168 |
+
dl['ocr_params'] = ocr_params
|
| 169 |
+
if 'translator_setup_params' in dl:
|
| 170 |
+
translator_params = dl.pop('translator_setup_params')
|
| 171 |
+
dl['translator_params'] = translator_params
|
| 172 |
+
config_dict['module'] = dl
|
| 173 |
+
|
| 174 |
+
if 'module' in config_dict:
|
| 175 |
+
module_cfg = config_dict['module']
|
| 176 |
+
trans_params = module_cfg['translator_params']
|
| 177 |
+
repl_pairs = {'baidu': 'Baidu', 'caiyun': 'Caiyun', 'chatgpt': 'ChatGPT', 'Deepl': 'DeepL', 'papago': 'Papago'}
|
| 178 |
+
for k, i in repl_pairs.items():
|
| 179 |
+
if k in trans_params:
|
| 180 |
+
trans_params[i] = trans_params.pop(k)
|
| 181 |
+
if module_cfg['translator'] in repl_pairs:
|
| 182 |
+
module_cfg['translator'] = repl_pairs[module_cfg['translator']]
|
| 183 |
+
|
| 184 |
+
return ProgramConfig(**config_dict)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
pcfg: ProgramConfig = None
|
| 188 |
+
text_styles: List[FontFormat] = []
|
| 189 |
+
active_format: FontFormat = None
|
| 190 |
+
|
| 191 |
+
def load_textstyle_from(p: str, raise_exception = False):
|
| 192 |
+
|
| 193 |
+
if not osp.exists(p):
|
| 194 |
+
LOGGER.warning(f'Text style {p} does not exist.')
|
| 195 |
+
return
|
| 196 |
+
|
| 197 |
+
try:
|
| 198 |
+
with open(p, 'r', encoding='utf8') as f:
|
| 199 |
+
style_list = json.loads(f.read())
|
| 200 |
+
styles_loaded = []
|
| 201 |
+
for style in style_list:
|
| 202 |
+
try:
|
| 203 |
+
styles_loaded.append(FontFormat(**style))
|
| 204 |
+
except Exception as e:
|
| 205 |
+
LOGGER.warning(f'Skip invalid text style: {style}')
|
| 206 |
+
except Exception as e:
|
| 207 |
+
LOGGER.error(f'Failed to load text style from {p}: {e}')
|
| 208 |
+
if raise_exception:
|
| 209 |
+
raise e
|
| 210 |
+
return
|
| 211 |
+
|
| 212 |
+
global text_styles, pcfg
|
| 213 |
+
if len(text_styles) > 0:
|
| 214 |
+
text_styles.clear()
|
| 215 |
+
text_styles.extend(styles_loaded)
|
| 216 |
+
pcfg.text_styles_path = p
|
| 217 |
+
|
| 218 |
+
def load_config(config_path: str = shared.CONFIG_PATH):
|
| 219 |
+
if config_path != shared.CONFIG_PATH:
|
| 220 |
+
shared.CONFIG_PATH = config_path
|
| 221 |
+
LOGGER.info(f'Using specified config file at {shared.CONFIG_PATH}')
|
| 222 |
+
|
| 223 |
+
if osp.exists(shared.CONFIG_PATH):
|
| 224 |
+
try:
|
| 225 |
+
config = ProgramConfig.load(shared.CONFIG_PATH)
|
| 226 |
+
except Exception as e:
|
| 227 |
+
LOGGER.exception(e)
|
| 228 |
+
LOGGER.warning("Failed to load config file, using default config")
|
| 229 |
+
config = ProgramConfig()
|
| 230 |
+
else:
|
| 231 |
+
LOGGER.info(f'{shared.CONFIG_PATH} does not exist, new config file will be created.')
|
| 232 |
+
config = ProgramConfig()
|
| 233 |
+
|
| 234 |
+
global pcfg
|
| 235 |
+
pcfg = config
|
| 236 |
+
|
| 237 |
+
p = pcfg.text_styles_path
|
| 238 |
+
if not osp.exists(pcfg.text_styles_path):
|
| 239 |
+
dp = osp.join(shared.DEFAULT_TEXTSTYLE_DIR, 'default.json')
|
| 240 |
+
if p != dp and osp.exists(dp):
|
| 241 |
+
p = dp
|
| 242 |
+
LOGGER.warning(f'Text style {p} does not exist, use the default from {dp}.')
|
| 243 |
+
else:
|
| 244 |
+
with open(dp, 'w', encoding='utf8') as f:
|
| 245 |
+
f.write(json.dumps([], ensure_ascii=False))
|
| 246 |
+
LOGGER.info(f'New text style file created at {dp}.')
|
| 247 |
+
load_textstyle_from(p)
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def json_dump_program_config(obj, **kwargs):
|
| 251 |
+
def _default(obj):
|
| 252 |
+
if isinstance(obj, (np.ndarray, np.ScalarType)):
|
| 253 |
+
return serialize_np(obj)
|
| 254 |
+
elif isinstance(obj, ModuleConfig):
|
| 255 |
+
return obj.get_saving_params()
|
| 256 |
+
return obj.__dict__
|
| 257 |
+
return json.dumps(obj, default=lambda o: _default(o), ensure_ascii=False, **kwargs)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def save_config():
|
| 261 |
+
global pcfg
|
| 262 |
+
try:
|
| 263 |
+
with open(shared.CONFIG_PATH, 'w', encoding='utf8') as f:
|
| 264 |
+
f.write(json_dump_program_config(pcfg))
|
| 265 |
+
LOGGER.info('Config saved')
|
| 266 |
+
return True
|
| 267 |
+
except Exception as e:
|
| 268 |
+
LOGGER.error(f'Failed save config to {shared.CONFIG_PATH}: {e}')
|
| 269 |
+
LOGGER.error(traceback.format_exc())
|
| 270 |
+
return False
|
| 271 |
+
|
| 272 |
+
def save_text_styles(raise_exception = False):
|
| 273 |
+
global pcfg, text_styles
|
| 274 |
+
try:
|
| 275 |
+
style_dir = osp.dirname(pcfg.text_styles_path)
|
| 276 |
+
if not osp.exists(style_dir):
|
| 277 |
+
os.makedirs(style_dir)
|
| 278 |
+
with open(pcfg.text_styles_path, 'w', encoding='utf8') as f:
|
| 279 |
+
f.write(json_dump_nested_obj(text_styles))
|
| 280 |
+
LOGGER.info('Text style saved')
|
| 281 |
+
return True
|
| 282 |
+
except Exception as e:
|
| 283 |
+
LOGGER.error(f'Failed save text style to {pcfg.text_styles_path}: {e}')
|
| 284 |
+
LOGGER.error(traceback.format_exc())
|
| 285 |
+
if raise_exception:
|
| 286 |
+
raise e
|
| 287 |
+
return False
|
utils/download_util.py
ADDED
|
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import requests
|
| 4 |
+
import traceback
|
| 5 |
+
import re
|
| 6 |
+
import sys
|
| 7 |
+
import shutil
|
| 8 |
+
import os.path as osp
|
| 9 |
+
from typing import List, Union
|
| 10 |
+
import hashlib
|
| 11 |
+
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from urllib.parse import urlparse
|
| 14 |
+
from torch.hub import download_url_to_file as _torchhub_download_url_to_file, get_dir
|
| 15 |
+
import requests
|
| 16 |
+
import tqdm
|
| 17 |
+
from py7zr import pack_7zarchive, unpack_7zarchive
|
| 18 |
+
import ssl
|
| 19 |
+
|
| 20 |
+
from . import shared
|
| 21 |
+
from .logger import logger as LOGGER
|
| 22 |
+
|
| 23 |
+
shutil.register_archive_format('7zip', pack_7zarchive, description='7zip archive')
|
| 24 |
+
shutil.register_unpack_format('7zip', ['.7z'], unpack_7zarchive)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def calculate_sha256(filename):
|
| 28 |
+
hash_sha256 = hashlib.sha256()
|
| 29 |
+
blksize = 1024 * 1024
|
| 30 |
+
|
| 31 |
+
with open(filename, "rb") as f:
|
| 32 |
+
for chunk in iter(lambda: f.read(blksize), b""):
|
| 33 |
+
hash_sha256.update(chunk)
|
| 34 |
+
|
| 35 |
+
return hash_sha256.hexdigest().lower()
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def sizeof_fmt(size, suffix='B'):
|
| 39 |
+
"""Get human readable file size.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
size (int): File size.
|
| 43 |
+
suffix (str): Suffix. Default: 'B'.
|
| 44 |
+
|
| 45 |
+
Return:
|
| 46 |
+
str: Formatted file size.
|
| 47 |
+
"""
|
| 48 |
+
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
| 49 |
+
if abs(size) < 1024.0:
|
| 50 |
+
return f'{size:3.1f} {unit}{suffix}'
|
| 51 |
+
size /= 1024.0
|
| 52 |
+
return f'{size:3.1f} Y{suffix}'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def download_file_from_google_drive(file_id, save_path):
|
| 56 |
+
"""Download files from google drive.
|
| 57 |
+
|
| 58 |
+
Ref:
|
| 59 |
+
https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
file_id (str): File id.
|
| 63 |
+
save_path (str): Save path.
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
session = requests.Session()
|
| 67 |
+
URL = 'https://docs.google.com/uc?export=download'
|
| 68 |
+
params = {'id': file_id, 'confirm': 't'} # https://stackoverflow.com/a/73893665/17671327
|
| 69 |
+
|
| 70 |
+
response = session.get(URL, params=params, stream=True)
|
| 71 |
+
token = get_confirm_token(response)
|
| 72 |
+
if token:
|
| 73 |
+
params['confirm'] = token
|
| 74 |
+
response = session.get(URL, params=params, stream=True)
|
| 75 |
+
|
| 76 |
+
# get file size
|
| 77 |
+
response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
|
| 78 |
+
if 'Content-Range' in response_file_size.headers:
|
| 79 |
+
file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
|
| 80 |
+
else:
|
| 81 |
+
file_size = None
|
| 82 |
+
|
| 83 |
+
save_response_content(response, save_path, file_size)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def get_confirm_token(response):
|
| 87 |
+
for key, value in response.cookies.items():
|
| 88 |
+
if key.startswith('download_warning'):
|
| 89 |
+
return value
|
| 90 |
+
return None
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def save_response_content(response, destination, file_size=None, chunk_size=32768):
|
| 94 |
+
if file_size is not None:
|
| 95 |
+
pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
|
| 96 |
+
|
| 97 |
+
readable_file_size = sizeof_fmt(file_size)
|
| 98 |
+
else:
|
| 99 |
+
pbar = None
|
| 100 |
+
|
| 101 |
+
with open(destination, 'wb') as f:
|
| 102 |
+
downloaded_size = 0
|
| 103 |
+
for chunk in response.iter_content(chunk_size):
|
| 104 |
+
downloaded_size += chunk_size
|
| 105 |
+
if pbar is not None:
|
| 106 |
+
pbar.update(1)
|
| 107 |
+
pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
|
| 108 |
+
if chunk: # filter out keep-alive new chunks
|
| 109 |
+
f.write(chunk)
|
| 110 |
+
if pbar is not None:
|
| 111 |
+
pbar.close()
|
| 112 |
+
|
| 113 |
+
# def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
|
| 114 |
+
# """Load file form http url, will download models if necessary.
|
| 115 |
+
|
| 116 |
+
# Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
|
| 117 |
+
|
| 118 |
+
# Args:
|
| 119 |
+
# url (str): URL to be downloaded.
|
| 120 |
+
# model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
|
| 121 |
+
# Default: None.
|
| 122 |
+
# progress (bool): Whether to show the download progress. Default: True.
|
| 123 |
+
# file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
|
| 124 |
+
|
| 125 |
+
# Returns:
|
| 126 |
+
# str: The path to the downloaded file.
|
| 127 |
+
# """
|
| 128 |
+
# if model_dir is None: # use the pytorch hub_dir
|
| 129 |
+
# hub_dir = get_dir()
|
| 130 |
+
# model_dir = os.path.join(hub_dir, 'checkpoints')
|
| 131 |
+
|
| 132 |
+
# os.makedirs(model_dir, exist_ok=True)
|
| 133 |
+
|
| 134 |
+
# parts = urlparse(url)
|
| 135 |
+
# filename = os.path.basename(parts.path)
|
| 136 |
+
# if file_name is not None:
|
| 137 |
+
# filename = file_name
|
| 138 |
+
# cached_file = os.path.abspath(os.path.join(model_dir, filename))
|
| 139 |
+
# if not os.path.exists(cached_file):
|
| 140 |
+
# print(f'Downloading: "{url}" to {cached_file}\n')
|
| 141 |
+
# download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
|
| 142 |
+
# return cached_file
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def torchhub_download_url_to_file(url: str, dst: str, progress: bool = True):
|
| 146 |
+
original_ctx = ssl._create_default_https_context
|
| 147 |
+
ssl._create_default_https_context = ssl._create_unverified_context # https://stackoverflow.com/questions/50236117/scraping-ssl-certificate-verify-failed-error-for-http-en-wikipedia-org
|
| 148 |
+
_torchhub_download_url_to_file(url, dst, progress=progress)
|
| 149 |
+
ssl._create_default_https_context = original_ctx
|
| 150 |
+
|
| 151 |
+
def check_local_file(local_file: str, sha256_precal: str = None, cache_hash: bool = False):
|
| 152 |
+
|
| 153 |
+
file_exists = osp.exists(local_file)
|
| 154 |
+
valid_hash, sha256_calculated = True, sha256_precal
|
| 155 |
+
|
| 156 |
+
if file_exists and sha256_precal is not None and shared.check_local_file_hash:
|
| 157 |
+
sha256_precal = sha256_precal.lower()
|
| 158 |
+
if cache_hash and local_file in shared.cache_data and shared.cache_data[local_file].lower() == sha256_precal:
|
| 159 |
+
pass
|
| 160 |
+
else:
|
| 161 |
+
sha256_calculated = calculate_sha256(local_file).lower()
|
| 162 |
+
if sha256_calculated != sha256_precal:
|
| 163 |
+
valid_hash = False
|
| 164 |
+
if cache_hash:
|
| 165 |
+
shared.cache_data[local_file] = sha256_calculated
|
| 166 |
+
shared.CACHE_UPDATED = True
|
| 167 |
+
|
| 168 |
+
return file_exists, valid_hash, sha256_calculated
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def get_filename_from_url(url: str, default: str = '') -> str:
|
| 172 |
+
m = re.search(r'/([^/?]+)[^/]*$', url)
|
| 173 |
+
if m:
|
| 174 |
+
return m.group(1)
|
| 175 |
+
return default
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def download_url_with_progressbar(url: str, path: str,):
|
| 179 |
+
if os.path.basename(path) in ('.', '') or os.path.isdir(path):
|
| 180 |
+
new_filename = get_filename_from_url(url)
|
| 181 |
+
if not new_filename:
|
| 182 |
+
raise Exception('Could not determine filename')
|
| 183 |
+
path = os.path.join(path, new_filename)
|
| 184 |
+
|
| 185 |
+
headers = {}
|
| 186 |
+
downloaded_size = 0
|
| 187 |
+
# the resume downloading here is buggy when the local file is corrupted or over-sized or intended to be replaced
|
| 188 |
+
# if os.path.isfile(path): # its actually buggy
|
| 189 |
+
# downloaded_size = os.path.getsize(path)
|
| 190 |
+
# headers['Range'] = 'bytes=%d-' % downloaded_size
|
| 191 |
+
# headers['Accept-Encoding'] = 'deflate'
|
| 192 |
+
|
| 193 |
+
r = requests.get(url, stream=True, allow_redirects=True, headers=headers)
|
| 194 |
+
if downloaded_size and r.headers.get('Accept-Ranges') != 'bytes':
|
| 195 |
+
print('Error: Webserver does not support partial downloads. Restarting from the beginning.')
|
| 196 |
+
r = requests.get(url, stream=True, allow_redirects=True)
|
| 197 |
+
downloaded_size = 0
|
| 198 |
+
total = int(r.headers.get('content-length', 0))
|
| 199 |
+
chunk_size = 1024
|
| 200 |
+
|
| 201 |
+
if r.ok:
|
| 202 |
+
with tqdm.tqdm(
|
| 203 |
+
desc=os.path.basename(path),
|
| 204 |
+
initial=downloaded_size,
|
| 205 |
+
total=total+downloaded_size,
|
| 206 |
+
unit='iB',
|
| 207 |
+
unit_scale=True,
|
| 208 |
+
unit_divisor=chunk_size,
|
| 209 |
+
) as bar:
|
| 210 |
+
with open(path, 'ab' if downloaded_size else 'wb') as f:
|
| 211 |
+
is_tty = sys.stdout.isatty()
|
| 212 |
+
downloaded_chunks = 0
|
| 213 |
+
for data in r.iter_content(chunk_size=chunk_size):
|
| 214 |
+
size = f.write(data)
|
| 215 |
+
bar.update(size)
|
| 216 |
+
|
| 217 |
+
# Fallback for non TTYs so output still shown
|
| 218 |
+
downloaded_chunks += 1
|
| 219 |
+
if not is_tty and downloaded_chunks % 1000 == 0:
|
| 220 |
+
print(bar)
|
| 221 |
+
else:
|
| 222 |
+
raise Exception(f'Couldn\'t resolve url: "{url}" (Error: {r.status_code})')
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def try_download_files(url: str,
|
| 227 |
+
files: List[str],
|
| 228 |
+
save_files = List[str],
|
| 229 |
+
sha256_pre_calculated: List[str] = None,
|
| 230 |
+
concatenate_url_filename: int = 0,
|
| 231 |
+
cache_hash: bool = False,
|
| 232 |
+
download_method: str = '',
|
| 233 |
+
gdrive_file_id: str = None):
|
| 234 |
+
|
| 235 |
+
all_successful = True
|
| 236 |
+
|
| 237 |
+
for file, savep, sha256_precal in zip(files, save_files, sha256_pre_calculated):
|
| 238 |
+
save_dir = osp.dirname(savep)
|
| 239 |
+
if not osp.exists(save_dir):
|
| 240 |
+
os.makedirs(save_dir)
|
| 241 |
+
|
| 242 |
+
file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=cache_hash)
|
| 243 |
+
if file_exists:
|
| 244 |
+
if valid_hash:
|
| 245 |
+
continue
|
| 246 |
+
else:
|
| 247 |
+
LOGGER.warning(f'Mismatch between local file {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}", it will be redownloaded...')
|
| 248 |
+
|
| 249 |
+
try:
|
| 250 |
+
if concatenate_url_filename == 1:
|
| 251 |
+
download_url = url + file
|
| 252 |
+
elif concatenate_url_filename == 2:
|
| 253 |
+
download_url = url + osp.basename(file)
|
| 254 |
+
else:
|
| 255 |
+
download_url = url
|
| 256 |
+
|
| 257 |
+
if gdrive_file_id is not None:
|
| 258 |
+
download_file_from_google_drive(gdrive_file_id, savep)
|
| 259 |
+
elif download_method == 'torch_hub':
|
| 260 |
+
LOGGER.info(f'downloading {savep} from {download_url} ...')
|
| 261 |
+
torchhub_download_url_to_file(download_url, savep)
|
| 262 |
+
else:
|
| 263 |
+
download_url_with_progressbar(download_url, savep)
|
| 264 |
+
file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=cache_hash)
|
| 265 |
+
if not file_exists:
|
| 266 |
+
raise Exception(f'Some how the downloaded {savep} doesnt exists.')
|
| 267 |
+
elif not valid_hash:
|
| 268 |
+
raise Exception(f'Mismatch between newly downloaded {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}"')
|
| 269 |
+
|
| 270 |
+
except:
|
| 271 |
+
err_msg = traceback.format_exc()
|
| 272 |
+
all_successful = False
|
| 273 |
+
LOGGER.error(err_msg)
|
| 274 |
+
LOGGER.error(f'Failed downloading {file} from {download_url}, please manually save it to {savep}')
|
| 275 |
+
|
| 276 |
+
return all_successful
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def download_and_check_files(url: str,
|
| 280 |
+
files: Union[str, List],
|
| 281 |
+
save_files = None,
|
| 282 |
+
sha256_pre_calculated: Union[str, List] = None,
|
| 283 |
+
concatenate_url_filename: int = 0,
|
| 284 |
+
archived_files: List = None,
|
| 285 |
+
archive_sha256_pre_calculated: Union[str, List] = None,
|
| 286 |
+
save_dir: str = None,
|
| 287 |
+
download_method: str = 'torch_hub',
|
| 288 |
+
gdrive_file_id: str = None):
|
| 289 |
+
|
| 290 |
+
def _wrap_up_checkinputs(files: Union[str, List], save_files: Union[str, List] = None, sha256_pre_calculated: Union[str, List] = None, save_dir: str = None):
|
| 291 |
+
'''
|
| 292 |
+
ensure they're lists with equal length
|
| 293 |
+
'''
|
| 294 |
+
if not isinstance(files, List):
|
| 295 |
+
files = [files]
|
| 296 |
+
if not isinstance(sha256_pre_calculated, List):
|
| 297 |
+
if sha256_pre_calculated is None:
|
| 298 |
+
sha256_pre_calculated = [None] * len(files)
|
| 299 |
+
else:
|
| 300 |
+
sha256_pre_calculated = [sha256_pre_calculated]
|
| 301 |
+
if save_files is None:
|
| 302 |
+
save_files = files
|
| 303 |
+
elif not isinstance(save_files, List):
|
| 304 |
+
save_files = [save_files]
|
| 305 |
+
|
| 306 |
+
assert len(files) == len(sha256_pre_calculated) == len(save_files)
|
| 307 |
+
|
| 308 |
+
if save_dir is not None:
|
| 309 |
+
_save_files = []
|
| 310 |
+
for savep in save_files:
|
| 311 |
+
_save_files.append(osp.join(save_dir, savep))
|
| 312 |
+
save_files = _save_files
|
| 313 |
+
|
| 314 |
+
return files, save_files, sha256_pre_calculated
|
| 315 |
+
|
| 316 |
+
def _all_valid(save_files: List[str] = None, sha256_pre_calculated: List[str] = None,):
|
| 317 |
+
for savep, sha256_precal in zip(save_files, sha256_pre_calculated):
|
| 318 |
+
file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=True)
|
| 319 |
+
if not file_exists or not valid_hash:
|
| 320 |
+
return False
|
| 321 |
+
return True
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
files, save_files, sha256_pre_calculated = _wrap_up_checkinputs(files, save_files, sha256_pre_calculated, save_dir)
|
| 325 |
+
|
| 326 |
+
if archived_files is None:
|
| 327 |
+
return try_download_files(url, files, save_files, sha256_pre_calculated, concatenate_url_filename, cache_hash=True, download_method=download_method, gdrive_file_id=gdrive_file_id)
|
| 328 |
+
|
| 329 |
+
# handle archived
|
| 330 |
+
if _all_valid(save_files, sha256_pre_calculated):
|
| 331 |
+
return [], None
|
| 332 |
+
|
| 333 |
+
if isinstance(archived_files, str):
|
| 334 |
+
archived_files = [archived_files]
|
| 335 |
+
|
| 336 |
+
# download archive files
|
| 337 |
+
tmp_downloaded_archives = [osp.join(shared.cache_dir, archive_name) for archive_name in archived_files]
|
| 338 |
+
_, _, archive_sha256_pre_calculated = _wrap_up_checkinputs(archived_files, tmp_downloaded_archives, archive_sha256_pre_calculated)
|
| 339 |
+
archive_downloaded = try_download_files(url, archived_files, tmp_downloaded_archives, archive_sha256_pre_calculated, concatenate_url_filename, cache_hash=False, download_method=download_method, gdrive_file_id=gdrive_file_id)
|
| 340 |
+
if not archive_downloaded:
|
| 341 |
+
return False
|
| 342 |
+
|
| 343 |
+
# extract archived
|
| 344 |
+
archivep = tmp_downloaded_archives[0] # todo: support multi-volume
|
| 345 |
+
extract_dir = osp.join(shared.cache_dir, 'tmp_extract')
|
| 346 |
+
os.makedirs(extract_dir, exist_ok=True)
|
| 347 |
+
LOGGER.info(f'Extracting {archivep} ...')
|
| 348 |
+
shutil.unpack_archive(archivep, extract_dir)
|
| 349 |
+
|
| 350 |
+
all_valid = True
|
| 351 |
+
for file, savep, sha256_precal in zip(files, save_files, sha256_pre_calculated):
|
| 352 |
+
unarchived = osp.join(extract_dir, file)
|
| 353 |
+
save_dir = osp.dirname(savep)
|
| 354 |
+
if not osp.exists(save_dir):
|
| 355 |
+
os.makedirs(save_dir)
|
| 356 |
+
shutil.move(unarchived, savep)
|
| 357 |
+
file_exists, valid_hash, sha256_calculated = check_local_file(savep, sha256_precal, cache_hash=True)
|
| 358 |
+
if not file_exists:
|
| 359 |
+
LOGGER.error(f'The unarchived file {savep} doesnt exists.')
|
| 360 |
+
all_valid = False
|
| 361 |
+
elif not valid_hash:
|
| 362 |
+
LOGGER.error(f'Mismatch between the unarchived {savep} and pre-calculated hash: "{sha256_calculated}" <-> "{sha256_precal.lower()}"')
|
| 363 |
+
all_valid = False
|
| 364 |
+
|
| 365 |
+
if all_valid:
|
| 366 |
+
# clean archive files
|
| 367 |
+
shutil.rmtree(extract_dir)
|
| 368 |
+
for p in tmp_downloaded_archives:
|
| 369 |
+
os.remove(p)
|
| 370 |
+
|
| 371 |
+
return all_valid
|
utils/exceptions.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class ProjectDirNotExistException(Exception):
|
| 2 |
+
pass
|
| 3 |
+
|
| 4 |
+
class ProjectLoadFailureException(Exception):
|
| 5 |
+
pass
|
| 6 |
+
|
| 7 |
+
class ProjectNotSupportedException(Exception):
|
| 8 |
+
pass
|
| 9 |
+
|
| 10 |
+
class ImgnameNotInProjectException(Exception):
|
| 11 |
+
pass
|
| 12 |
+
|
| 13 |
+
class NotImplementedProjException(Exception):
|
| 14 |
+
pass
|
| 15 |
+
|
| 16 |
+
class InvalidModuleConfigException(Exception):
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
class InvalidProgramConfigException(Exception):
|
| 20 |
+
pass
|
utils/fontformat.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union
|
| 2 |
+
import enum
|
| 3 |
+
import re
|
| 4 |
+
import copy
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from . import shared
|
| 9 |
+
from .structures import Tuple, Union, List, Dict, Config, field, nested_dataclass
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def pt2px(pt, to_int=False) -> float:
|
| 13 |
+
if to_int:
|
| 14 |
+
return int(round(pt * shared.LDPI / 72.))
|
| 15 |
+
else:
|
| 16 |
+
return pt * shared.LDPI / 72.
|
| 17 |
+
|
| 18 |
+
def px2pt(px) -> float:
|
| 19 |
+
return px / shared.LDPI * 72.
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class LineSpacingType(enum.IntEnum):
|
| 23 |
+
Proportional = 0
|
| 24 |
+
Distance = 1
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TextAlignment(enum.IntEnum):
|
| 28 |
+
Left = 0
|
| 29 |
+
Center = 1
|
| 30 |
+
Right = 2
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
fontweight_qt5_to_qt6 = {0: 100, 12: 200, 25: 300, 50: 400, 57: 500, 63: 600, 75: 700, 81: 800, 87: 900}
|
| 34 |
+
fontweight_qt6_to_qt5 = {100: 0, 200: 12, 300: 25, 400: 50, 500: 57, 600: 63, 700: 75, 800: 81, 900: 87}
|
| 35 |
+
|
| 36 |
+
fontweight_pattern = re.compile(r'font-weight:(\d+)', re.DOTALL)
|
| 37 |
+
|
| 38 |
+
def fix_fontweight_qt(weight: Union[str, int]):
|
| 39 |
+
|
| 40 |
+
def _fix_html_fntweight(matched):
|
| 41 |
+
weight = int(matched.group(1))
|
| 42 |
+
return f'font-weight:{fix_fontweight_qt(weight)}'
|
| 43 |
+
|
| 44 |
+
if weight is None:
|
| 45 |
+
return None
|
| 46 |
+
if isinstance(weight, int):
|
| 47 |
+
if shared.FLAG_QT6 and weight < 100:
|
| 48 |
+
if weight in fontweight_qt5_to_qt6:
|
| 49 |
+
weight = fontweight_qt5_to_qt6[weight]
|
| 50 |
+
if not shared.FLAG_QT6 and weight >= 100:
|
| 51 |
+
if weight in fontweight_qt6_to_qt5:
|
| 52 |
+
weight = fontweight_qt6_to_qt5[weight]
|
| 53 |
+
if isinstance(weight, str):
|
| 54 |
+
weight = fontweight_pattern.sub(lambda matched: _fix_html_fntweight(matched), weight)
|
| 55 |
+
return weight
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@nested_dataclass
|
| 59 |
+
class FontFormat(Config):
|
| 60 |
+
|
| 61 |
+
font_family: str = shared.DEFAULT_FONT_FAMILY # to always apply shared.DEFAULT_FONT_FAMILY
|
| 62 |
+
font_size: float = 24
|
| 63 |
+
stroke_width: float = 0.
|
| 64 |
+
frgb: List = field(default_factory=lambda: [0, 0, 0])
|
| 65 |
+
srgb: List = field(default_factory=lambda: [0, 0, 0])
|
| 66 |
+
bold: bool = False
|
| 67 |
+
underline: bool = False
|
| 68 |
+
italic: bool = False
|
| 69 |
+
alignment: int = 0
|
| 70 |
+
vertical: bool = False
|
| 71 |
+
font_weight: int = None
|
| 72 |
+
line_spacing: float = 1.2
|
| 73 |
+
letter_spacing: float = 1.15
|
| 74 |
+
opacity: float = 1.
|
| 75 |
+
shadow_radius: float = 0.
|
| 76 |
+
shadow_strength: float = 1.
|
| 77 |
+
shadow_color: List = field(default_factory=lambda: [0, 0, 0])
|
| 78 |
+
shadow_offset: List = field(default_factory=lambda: [0., 0.])
|
| 79 |
+
gradient_enabled: bool = False
|
| 80 |
+
gradient_start_color: List = field(default_factory=lambda: [0, 0, 0])
|
| 81 |
+
gradient_end_color: List = field(default_factory=lambda: [255, 255, 255])
|
| 82 |
+
gradient_angle: float = 0.
|
| 83 |
+
gradient_size: float = 1.0
|
| 84 |
+
_style_name: str = ''
|
| 85 |
+
line_spacing_type: int = LineSpacingType.Proportional
|
| 86 |
+
|
| 87 |
+
deprecated_attributes: dict = field(default_factory = lambda: dict())
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def size_pt(self):
|
| 91 |
+
return px2pt(self.font_size)
|
| 92 |
+
|
| 93 |
+
def __post_init__(self):
|
| 94 |
+
da = self.deprecated_attributes
|
| 95 |
+
if len(da) > 0:
|
| 96 |
+
if 'size' in da:
|
| 97 |
+
self.font_size = pt2px(da['size'])
|
| 98 |
+
if 'weight' in da:
|
| 99 |
+
self.font_weight = da['weight']
|
| 100 |
+
if 'family' in da:
|
| 101 |
+
self.font_family = da['family']
|
| 102 |
+
|
| 103 |
+
self.font_weight = fix_fontweight_qt(self.font_weight)
|
| 104 |
+
self.deprecated_attributes = {}
|
| 105 |
+
|
| 106 |
+
def deepcopy(self):
|
| 107 |
+
fmt_copyed: FontFormat = None
|
| 108 |
+
fmt_copyed = copy.deepcopy(self)
|
| 109 |
+
return fmt_copyed
|
| 110 |
+
|
| 111 |
+
def merge(self, target: Config, compare: bool = False):
|
| 112 |
+
if id(self) == id(target):
|
| 113 |
+
return set()
|
| 114 |
+
tgt_keys = target.annotations_set()
|
| 115 |
+
updated_keys = set()
|
| 116 |
+
for key in tgt_keys:
|
| 117 |
+
if not hasattr(self, key):
|
| 118 |
+
continue
|
| 119 |
+
if compare:
|
| 120 |
+
if key != '_style_name':
|
| 121 |
+
if isinstance(target[key], np.ndarray):
|
| 122 |
+
is_diff = np.any(self[key] != target[key])
|
| 123 |
+
else:
|
| 124 |
+
is_diff = self[key] != target[key]
|
| 125 |
+
if is_diff:
|
| 126 |
+
self.update(key, copy.deepcopy(target[key]))
|
| 127 |
+
updated_keys.add(key)
|
| 128 |
+
else:
|
| 129 |
+
self.update(key, copy.deepcopy(target[key]))
|
| 130 |
+
return updated_keys
|
| 131 |
+
|
| 132 |
+
def foreground_color(self):
|
| 133 |
+
return [int(round(x)) for x in self.frgb]
|
| 134 |
+
|
| 135 |
+
def stroke_color(self):
|
| 136 |
+
return [int(round(x)) for x in self.srgb]
|
utils/imgproc_utils.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import cv2
|
| 3 |
+
import random
|
| 4 |
+
from typing import List, Tuple, Union
|
| 5 |
+
|
| 6 |
+
def hex2bgr(hex):
|
| 7 |
+
gmask = 254 << 8
|
| 8 |
+
rmask = 254
|
| 9 |
+
b = hex >> 16
|
| 10 |
+
g = (hex & gmask) >> 8
|
| 11 |
+
r = hex & rmask
|
| 12 |
+
return np.stack([b, g, r]).transpose()
|
| 13 |
+
|
| 14 |
+
def union_area(bboxa, bboxb):
|
| 15 |
+
x1 = max(bboxa[0], bboxb[0])
|
| 16 |
+
y1 = max(bboxa[1], bboxb[1])
|
| 17 |
+
x2 = min(bboxa[2], bboxb[2])
|
| 18 |
+
y2 = min(bboxa[3], bboxb[3])
|
| 19 |
+
if y2 < y1 or x2 < x1:
|
| 20 |
+
return -1
|
| 21 |
+
return (y2 - y1) * (x2 - x1)
|
| 22 |
+
|
| 23 |
+
def get_yololabel_strings(clslist, labellist):
|
| 24 |
+
content = ''
|
| 25 |
+
for cls, xywh in zip(clslist, labellist):
|
| 26 |
+
content += str(int(cls)) + ' ' + ' '.join([str(e) for e in xywh]) + '\n'
|
| 27 |
+
if len(content) != 0:
|
| 28 |
+
content = content[:-1]
|
| 29 |
+
return content
|
| 30 |
+
|
| 31 |
+
# 4 points bbox to 8 points polygon
|
| 32 |
+
def xywh2xyxypoly(xywh, to_int=True):
|
| 33 |
+
xyxypoly = np.tile(xywh[:, [0, 1]], 4)
|
| 34 |
+
xyxypoly[:, [2, 4]] += xywh[:, [2]]
|
| 35 |
+
xyxypoly[:, [5, 7]] += xywh[:, [3]]
|
| 36 |
+
if to_int:
|
| 37 |
+
xyxypoly = xyxypoly.astype(np.int64)
|
| 38 |
+
return xyxypoly
|
| 39 |
+
|
| 40 |
+
def xyxy2yolo(xyxy, w: int, h: int):
|
| 41 |
+
if xyxy == [] or xyxy == np.array([]) or len(xyxy) == 0:
|
| 42 |
+
return None
|
| 43 |
+
if isinstance(xyxy, list):
|
| 44 |
+
xyxy = np.array(xyxy)
|
| 45 |
+
if len(xyxy.shape) == 1:
|
| 46 |
+
xyxy = np.array([xyxy])
|
| 47 |
+
yolo = np.copy(xyxy).astype(np.float64)
|
| 48 |
+
yolo[:, [0, 2]] = yolo[:, [0, 2]] / w
|
| 49 |
+
yolo[:, [1, 3]] = yolo[:, [1, 3]] / h
|
| 50 |
+
yolo[:, [2, 3]] -= yolo[:, [0, 1]]
|
| 51 |
+
yolo[:, [0, 1]] += yolo[:, [2, 3]] / 2
|
| 52 |
+
return yolo
|
| 53 |
+
|
| 54 |
+
def yolo_xywh2xyxy(xywh: np.array, w: int, h: int, to_int=True):
|
| 55 |
+
if xywh is None:
|
| 56 |
+
return None
|
| 57 |
+
if len(xywh) == 0:
|
| 58 |
+
return None
|
| 59 |
+
if len(xywh.shape) == 1:
|
| 60 |
+
xywh = np.array([xywh])
|
| 61 |
+
xywh[:, [0, 2]] *= w
|
| 62 |
+
xywh[:, [1, 3]] *= h
|
| 63 |
+
xywh[:, [0, 1]] -= xywh[:, [2, 3]] / 2
|
| 64 |
+
xywh[:, [2, 3]] += xywh[:, [0, 1]]
|
| 65 |
+
if to_int:
|
| 66 |
+
xywh = xywh.astype(np.int64)
|
| 67 |
+
return xywh
|
| 68 |
+
|
| 69 |
+
def rotate_polygons(center, polygons, rotation, new_center=None, to_int=True):
|
| 70 |
+
if new_center is None:
|
| 71 |
+
new_center = center
|
| 72 |
+
rotation = np.deg2rad(rotation)
|
| 73 |
+
s, c = np.sin(rotation), np.cos(rotation)
|
| 74 |
+
polygons = polygons.astype(np.float32)
|
| 75 |
+
|
| 76 |
+
polygons[:, 1::2] -= center[1]
|
| 77 |
+
polygons[:, ::2] -= center[0]
|
| 78 |
+
rotated = np.copy(polygons)
|
| 79 |
+
rotated[:, 1::2] = polygons[:, 1::2] * c - polygons[:, ::2] * s
|
| 80 |
+
rotated[:, ::2] = polygons[:, 1::2] * s + polygons[:, ::2] * c
|
| 81 |
+
rotated[:, 1::2] += new_center[1]
|
| 82 |
+
rotated[:, ::2] += new_center[0]
|
| 83 |
+
if to_int:
|
| 84 |
+
return rotated.astype(np.int64)
|
| 85 |
+
return rotated
|
| 86 |
+
|
| 87 |
+
def letterbox(im, new_shape=(640, 640), color=(0, 0, 0), auto=False, scaleFill=False, scaleup=True, stride=128):
|
| 88 |
+
# Resize and pad image while meeting stride-multiple constraints
|
| 89 |
+
shape = im.shape[:2] # current shape [height, width]
|
| 90 |
+
if not isinstance(new_shape, tuple):
|
| 91 |
+
new_shape = (new_shape, new_shape)
|
| 92 |
+
|
| 93 |
+
# Scale ratio (new / old)
|
| 94 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
| 95 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
| 96 |
+
r = min(r, 1.0)
|
| 97 |
+
|
| 98 |
+
# Compute padding
|
| 99 |
+
ratio = r, r # width, height ratios
|
| 100 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
| 101 |
+
dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding
|
| 102 |
+
if auto: # minimum rectangle
|
| 103 |
+
dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding
|
| 104 |
+
elif scaleFill: # stretch
|
| 105 |
+
dw, dh = 0.0, 0.0
|
| 106 |
+
new_unpad = (new_shape[1], new_shape[0])
|
| 107 |
+
ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios
|
| 108 |
+
|
| 109 |
+
# dw /= 2 # divide padding into 2 sides
|
| 110 |
+
# dh /= 2
|
| 111 |
+
dh, dw = int(dh), int(dw)
|
| 112 |
+
|
| 113 |
+
if shape[::-1] != new_unpad: # resize
|
| 114 |
+
im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
|
| 115 |
+
top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
|
| 116 |
+
left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
|
| 117 |
+
im = cv2.copyMakeBorder(im, 0, dh, 0, dw, cv2.BORDER_CONSTANT, value=color) # add border
|
| 118 |
+
return im, ratio, (dw, dh)
|
| 119 |
+
|
| 120 |
+
def resize_keepasp(im, new_shape=640, scaleup=True, interpolation=cv2.INTER_LINEAR, stride=None):
|
| 121 |
+
shape = im.shape[:2] # current shape [height, width]
|
| 122 |
+
|
| 123 |
+
if new_shape is not None:
|
| 124 |
+
if not isinstance(new_shape, tuple):
|
| 125 |
+
new_shape = (new_shape, new_shape)
|
| 126 |
+
else:
|
| 127 |
+
new_shape = shape
|
| 128 |
+
|
| 129 |
+
# Scale ratio (new / old)
|
| 130 |
+
r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
|
| 131 |
+
if not scaleup: # only scale down, do not scale up (for better val mAP)
|
| 132 |
+
r = min(r, 1.0)
|
| 133 |
+
|
| 134 |
+
new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
|
| 135 |
+
|
| 136 |
+
if stride is not None:
|
| 137 |
+
h, w = new_unpad
|
| 138 |
+
if h % stride != 0 :
|
| 139 |
+
new_h = (stride - (h % stride)) + h
|
| 140 |
+
else :
|
| 141 |
+
new_h = h
|
| 142 |
+
if w % stride != 0 :
|
| 143 |
+
new_w = (stride - (w % stride)) + w
|
| 144 |
+
else :
|
| 145 |
+
new_w = w
|
| 146 |
+
new_unpad = (new_h, new_w)
|
| 147 |
+
|
| 148 |
+
if shape[::-1] != new_unpad: # resize
|
| 149 |
+
im = cv2.resize(im, new_unpad, interpolation=interpolation)
|
| 150 |
+
return im
|
| 151 |
+
|
| 152 |
+
def expand_textwindow(img_size, xyxy, expand_r=8, shrink=False):
|
| 153 |
+
im_h, im_w = img_size[:2]
|
| 154 |
+
x1, y1 , x2, y2 = xyxy
|
| 155 |
+
w = x2 - x1
|
| 156 |
+
h = y2 - y1
|
| 157 |
+
paddings = int(round((max(h, w) * 0.25 + min(h, w) * 0.75) / expand_r))
|
| 158 |
+
if shrink:
|
| 159 |
+
paddings *= -1
|
| 160 |
+
x1, y1 = max(0, x1 - paddings), max(0, y1 - paddings)
|
| 161 |
+
x2, y2 = min(im_w-1, x2+paddings), min(im_h-1, y2+paddings)
|
| 162 |
+
return [x1, y1, x2, y2]
|
| 163 |
+
|
| 164 |
+
def enlarge_window(rect, im_w, im_h, ratio=2.5, aspect_ratio=1.0) -> List:
|
| 165 |
+
assert ratio > 1.0
|
| 166 |
+
|
| 167 |
+
x1, y1, x2, y2 = rect
|
| 168 |
+
w = x2 - x1
|
| 169 |
+
h = y2 - y1
|
| 170 |
+
|
| 171 |
+
if w <= 0 or h <= 0:
|
| 172 |
+
return [0, 0, 0, 0]
|
| 173 |
+
|
| 174 |
+
# https://numpy.org/doc/stable/reference/generated/numpy.roots.html
|
| 175 |
+
coeff = [aspect_ratio, w+h*aspect_ratio, (1-ratio)*w*h]
|
| 176 |
+
roots = np.roots(coeff)
|
| 177 |
+
roots.sort()
|
| 178 |
+
delta = int(round(roots[-1] / 2 ))
|
| 179 |
+
delta_w = int(delta * aspect_ratio)
|
| 180 |
+
delta_w = min(x1, im_w - x2, delta_w)
|
| 181 |
+
delta = min(y1, im_h - y2, delta)
|
| 182 |
+
rect = np.array([x1-delta_w, y1-delta, x2+delta_w, y2+delta], dtype=np.int64)
|
| 183 |
+
rect[::2] = np.clip(rect[::2], 0, im_w)
|
| 184 |
+
rect[1::2] = np.clip(rect[1::2], 0, im_h)
|
| 185 |
+
return rect.tolist()
|
| 186 |
+
|
| 187 |
+
def draw_connected_labels(num_labels, labels, stats, centroids, names="draw_connected_labels", skip_background=True):
|
| 188 |
+
labdraw = np.zeros((labels.shape[0], labels.shape[1], 3), dtype=np.uint8)
|
| 189 |
+
max_ind = 0
|
| 190 |
+
if isinstance(num_labels, int):
|
| 191 |
+
num_labels = range(num_labels)
|
| 192 |
+
|
| 193 |
+
# for ind, lab in enumerate((range(num_labels))):
|
| 194 |
+
for lab in num_labels:
|
| 195 |
+
if skip_background and lab == 0:
|
| 196 |
+
continue
|
| 197 |
+
randcolor = (random.randint(0,255), random.randint(0,255), random.randint(0,255))
|
| 198 |
+
labdraw[np.where(labels==lab)] = randcolor
|
| 199 |
+
maxr, minr = 0.5, 0.001
|
| 200 |
+
maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr
|
| 201 |
+
minarea = labdraw.shape[0] * labdraw.shape[1] * minr
|
| 202 |
+
|
| 203 |
+
stat = stats[lab]
|
| 204 |
+
bboxarea = stat[2] * stat[3]
|
| 205 |
+
if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
|
| 206 |
+
pix = np.zeros((labels.shape[0], labels.shape[1]), dtype=np.uint8)
|
| 207 |
+
pix[np.where(labels==lab)] = 255
|
| 208 |
+
|
| 209 |
+
rect = cv2.minAreaRect(cv2.findNonZero(pix))
|
| 210 |
+
box = np.int0(cv2.boxPoints(rect))
|
| 211 |
+
labdraw = cv2.drawContours(labdraw, [box], 0, randcolor, 2)
|
| 212 |
+
labdraw = cv2.circle(labdraw, (int(centroids[lab][0]),int(centroids[lab][1])), radius=5, color=(random.randint(0,255), random.randint(0,255), random.randint(0,255)), thickness=-1)
|
| 213 |
+
|
| 214 |
+
cv2.imshow(names, labdraw)
|
| 215 |
+
return labdraw
|
| 216 |
+
|
| 217 |
+
def rotate_image(mat: np.ndarray, angle: float) -> np.ndarray:
|
| 218 |
+
"""
|
| 219 |
+
Rotates an image (angle in degrees) and expands image to avoid cropping
|
| 220 |
+
# https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
|
| 221 |
+
"""
|
| 222 |
+
|
| 223 |
+
height, width = mat.shape[:2] # image shape has 3 dimensions
|
| 224 |
+
image_center = (width/2, height/2) # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape
|
| 225 |
+
|
| 226 |
+
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.)
|
| 227 |
+
|
| 228 |
+
# rotation calculates the cos and sin, taking absolutes of those.
|
| 229 |
+
abs_cos = abs(rotation_mat[0,0])
|
| 230 |
+
abs_sin = abs(rotation_mat[0,1])
|
| 231 |
+
|
| 232 |
+
# find the new width and height bounds
|
| 233 |
+
bound_w = int(height * abs_sin + width * abs_cos)
|
| 234 |
+
bound_h = int(height * abs_cos + width * abs_sin)
|
| 235 |
+
|
| 236 |
+
# subtract old image center (bringing image back to origo) and adding the new image center coordinates
|
| 237 |
+
rotation_mat[0, 2] += bound_w/2 - image_center[0]
|
| 238 |
+
rotation_mat[1, 2] += bound_h/2 - image_center[1]
|
| 239 |
+
|
| 240 |
+
# rotate image with the new bounds and translated rotation matrix
|
| 241 |
+
rotated_mat = cv2.warpAffine(mat, rotation_mat, (bound_w, bound_h))
|
| 242 |
+
return rotated_mat
|
| 243 |
+
|
| 244 |
+
def color_difference(rgb1: List, rgb2: List) -> float:
|
| 245 |
+
# https://en.wikipedia.org/wiki/Color_difference#CIE76
|
| 246 |
+
color1 = np.array(rgb1, dtype=np.uint8).reshape(1, 1, 3)
|
| 247 |
+
color2 = np.array(rgb2, dtype=np.uint8).reshape(1, 1, 3)
|
| 248 |
+
diff = cv2.cvtColor(color1, cv2.COLOR_RGB2LAB).astype(np.float64) - cv2.cvtColor(color2, cv2.COLOR_RGB2LAB).astype(np.float64)
|
| 249 |
+
diff[..., 0] *= 0.392
|
| 250 |
+
diff = np.linalg.norm(diff, axis=2)
|
| 251 |
+
return diff.item()
|
| 252 |
+
|
| 253 |
+
def extract_ballon_region(img: np.ndarray, ballon_rect: List, show_process=False, enlarge_ratio=2.0, cal_region_rect=False) -> Tuple[np.ndarray, int, List]:
|
| 254 |
+
WHITE = (255, 255, 255)
|
| 255 |
+
BLACK = (0, 0, 0)
|
| 256 |
+
|
| 257 |
+
x1, y1, x2, y2 = ballon_rect[0], ballon_rect[1], \
|
| 258 |
+
ballon_rect[2] + ballon_rect[0], ballon_rect[3] + ballon_rect[1]
|
| 259 |
+
if enlarge_ratio > 1:
|
| 260 |
+
x1, y1, x2, y2 = enlarge_window([x1, y1, x2, y2], img.shape[1], img.shape[0], enlarge_ratio, aspect_ratio=ballon_rect[3] / ballon_rect[2])
|
| 261 |
+
|
| 262 |
+
img = img[y1:y2, x1:x2].copy()
|
| 263 |
+
|
| 264 |
+
kernel = np.ones((3,3),np.uint8)
|
| 265 |
+
orih, oriw = img.shape[0], img.shape[1]
|
| 266 |
+
scaleR = 1
|
| 267 |
+
if orih > 300 and oriw > 300:
|
| 268 |
+
scaleR = 0.6
|
| 269 |
+
elif orih < 120 or oriw < 120:
|
| 270 |
+
scaleR = 1.4
|
| 271 |
+
|
| 272 |
+
if scaleR != 1:
|
| 273 |
+
h, w = img.shape[0], img.shape[1]
|
| 274 |
+
orimg = np.copy(img)
|
| 275 |
+
img = cv2.resize(img, (int(w*scaleR), int(h*scaleR)), interpolation=cv2.INTER_AREA)
|
| 276 |
+
h, w = img.shape[0], img.shape[1]
|
| 277 |
+
img_area = h * w
|
| 278 |
+
|
| 279 |
+
cpimg = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
|
| 280 |
+
detected_edges = cv2.Canny(cpimg, 70, 140, L2gradient=True, apertureSize=3)
|
| 281 |
+
cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
|
| 282 |
+
cons, hiers = cv2.findContours(detected_edges, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
| 283 |
+
cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), BLACK, 1, cv2.LINE_8)
|
| 284 |
+
|
| 285 |
+
ballon_mask, outer_index = np.zeros((h, w), np.uint8), -1
|
| 286 |
+
min_retval = np.inf
|
| 287 |
+
mask = np.zeros((h, w), np.uint8)
|
| 288 |
+
difres = 10
|
| 289 |
+
seedpnt = (int(w/2), int(h/2))
|
| 290 |
+
for ii in range(len(cons)):
|
| 291 |
+
rect = cv2.boundingRect(cons[ii])
|
| 292 |
+
if rect[2]*rect[3] < img_area*0.4:
|
| 293 |
+
continue
|
| 294 |
+
|
| 295 |
+
mask = cv2.drawContours(mask, cons, ii, (255), 2)
|
| 296 |
+
cpmask = np.copy(mask)
|
| 297 |
+
cv2.rectangle(mask, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
|
| 298 |
+
retval, _, _, rect = cv2.floodFill(cpmask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 299 |
+
|
| 300 |
+
if retval <= img_area * 0.3:
|
| 301 |
+
mask = cv2.drawContours(mask, cons, ii, (0), 2)
|
| 302 |
+
if retval < min_retval and retval > img_area * 0.3:
|
| 303 |
+
min_retval = retval
|
| 304 |
+
ballon_mask = cpmask
|
| 305 |
+
|
| 306 |
+
ballon_mask = 127 - ballon_mask
|
| 307 |
+
ballon_mask = cv2.dilate(ballon_mask, kernel,iterations = 1)
|
| 308 |
+
ballon_area, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(30), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 309 |
+
ballon_mask = 30 - ballon_mask
|
| 310 |
+
retval, ballon_mask = cv2.threshold(ballon_mask, 1, 255, cv2.THRESH_BINARY)
|
| 311 |
+
ballon_mask = cv2.bitwise_not(ballon_mask, ballon_mask)
|
| 312 |
+
|
| 313 |
+
box_kernel = int(np.sqrt(ballon_area) / 30)
|
| 314 |
+
if box_kernel > 1:
|
| 315 |
+
box_kernel = np.ones((box_kernel,box_kernel),np.uint8)
|
| 316 |
+
ballon_mask = cv2.dilate(ballon_mask, box_kernel, iterations = 1)
|
| 317 |
+
ballon_mask = cv2.erode(ballon_mask, box_kernel, iterations = 1)
|
| 318 |
+
|
| 319 |
+
if scaleR != 1:
|
| 320 |
+
img = orimg
|
| 321 |
+
ballon_mask = cv2.resize(ballon_mask, (oriw, orih))
|
| 322 |
+
|
| 323 |
+
if show_process:
|
| 324 |
+
cv2.imshow('ballon_mask', ballon_mask)
|
| 325 |
+
cv2.imshow('img', img)
|
| 326 |
+
cv2.waitKey(0)
|
| 327 |
+
if cal_region_rect:
|
| 328 |
+
return ballon_mask, (ballon_mask > 0).sum(), [x1, y1, x2, y2], cv2.boundingRect(ballon_mask)
|
| 329 |
+
return ballon_mask, (ballon_mask > 0).sum(), [x1, y1, x2, y2]
|
| 330 |
+
|
| 331 |
+
def square_pad_resize(img: np.ndarray, tgt_size: int):
|
| 332 |
+
h, w = img.shape[:2]
|
| 333 |
+
pad_h, pad_w = 0, 0
|
| 334 |
+
|
| 335 |
+
# make square image
|
| 336 |
+
if w < h:
|
| 337 |
+
pad_w = h - w
|
| 338 |
+
w += pad_w
|
| 339 |
+
elif h < w:
|
| 340 |
+
pad_h = w - h
|
| 341 |
+
h += pad_h
|
| 342 |
+
|
| 343 |
+
pad_size = tgt_size - h
|
| 344 |
+
if pad_size > 0:
|
| 345 |
+
pad_h += pad_size
|
| 346 |
+
pad_w += pad_size
|
| 347 |
+
|
| 348 |
+
if pad_h > 0 or pad_w > 0:
|
| 349 |
+
img = cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_CONSTANT)
|
| 350 |
+
|
| 351 |
+
down_scale_ratio = tgt_size / img.shape[0]
|
| 352 |
+
assert down_scale_ratio <= 1
|
| 353 |
+
if down_scale_ratio < 1:
|
| 354 |
+
img = cv2.resize(img, (tgt_size, tgt_size), interpolation=cv2.INTER_AREA)
|
| 355 |
+
|
| 356 |
+
return img, down_scale_ratio, pad_h, pad_w
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def get_block_mask(xywh: List, mask_array: np.ndarray, angle: int):
|
| 361 |
+
x, y, w, h = xywh
|
| 362 |
+
im_h, im_w = mask_array.shape[:2]
|
| 363 |
+
|
| 364 |
+
if angle != 0:
|
| 365 |
+
cx, cy = x + int(round(w / 2)), y + int(round(h / 2))
|
| 366 |
+
poly = xywh2xyxypoly(np.array([[x, y, w, h]]))
|
| 367 |
+
poly = rotate_polygons([cx, cy], poly, -angle)
|
| 368 |
+
|
| 369 |
+
x1, x2 = np.min(poly[..., ::2]), np.max(poly[..., ::2])
|
| 370 |
+
y1, y2 = np.min(poly[..., 1::2]), np.max(poly[..., 1::2])
|
| 371 |
+
|
| 372 |
+
if x2 < 0 or x2 - x1 < 2 or x1 >= im_w - 1 \
|
| 373 |
+
or y2 < 0 or y2 - y1 < 2 or y1 >= im_h - 1:
|
| 374 |
+
return None, None
|
| 375 |
+
else:
|
| 376 |
+
poly[..., ::2] -= cx - int((x2 - x1) / 2)
|
| 377 |
+
poly[..., 1::2] -= cy - int((y2 - y1) / 2)
|
| 378 |
+
itmsk = np.zeros((y2 - y1, x2 - x1), np.uint8)
|
| 379 |
+
|
| 380 |
+
cv2.fillPoly(itmsk, poly.reshape(-1, 4, 2), color=(255))
|
| 381 |
+
px1, px2, py1, py2 = 0, itmsk.shape[1], 0, itmsk.shape[0]
|
| 382 |
+
if x1 < 0:
|
| 383 |
+
px1 = -x1
|
| 384 |
+
x1 = 0
|
| 385 |
+
if x2 > im_w:
|
| 386 |
+
px2 = im_w - x2
|
| 387 |
+
x2 = im_w
|
| 388 |
+
if y1 < 0:
|
| 389 |
+
py1 = -y1
|
| 390 |
+
y1 = 0
|
| 391 |
+
if y2 > im_h:
|
| 392 |
+
py2 = im_h - y2
|
| 393 |
+
y2 = im_h
|
| 394 |
+
itmsk = itmsk[py1: py2, px1: px2]
|
| 395 |
+
msk = cv2.bitwise_and(mask_array[y1: y2, x1: x2], itmsk)
|
| 396 |
+
else:
|
| 397 |
+
x1, y1, x2, y2 = x, y, x+w, y+h
|
| 398 |
+
if x2 < 0 or x2 - x1 < 2 or x1 >= im_w - 1 \
|
| 399 |
+
or y2 < 0 or y2 - y1 < 2 or y1 >= im_h - 1:
|
| 400 |
+
return None, None
|
| 401 |
+
else:
|
| 402 |
+
if x1 < 0:
|
| 403 |
+
x1 = 0
|
| 404 |
+
if x2 > im_w:
|
| 405 |
+
x2 = im_w
|
| 406 |
+
if y1 < 0:
|
| 407 |
+
y1 = 0
|
| 408 |
+
if y2 > im_h:
|
| 409 |
+
y2 = im_h
|
| 410 |
+
msk = mask_array[y1: y2, x1: x2]
|
| 411 |
+
|
| 412 |
+
return msk, [x1, y1, x2, y2]
|
| 413 |
+
|
utils/io_utils.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json, os, sys, time, io
|
| 2 |
+
import os.path as osp
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import importlib
|
| 5 |
+
from typing import List, Dict, Callable, Union
|
| 6 |
+
import base64
|
| 7 |
+
import traceback
|
| 8 |
+
|
| 9 |
+
from .logger import logger as LOGGER
|
| 10 |
+
import requests
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import PIL
|
| 13 |
+
import cv2
|
| 14 |
+
import numpy as np
|
| 15 |
+
import pillow_jxl
|
| 16 |
+
from natsort import natsorted
|
| 17 |
+
|
| 18 |
+
IMG_EXT = ['.bmp', '.jpg', '.png', '.jpeg', '.webp', '.jxl']
|
| 19 |
+
|
| 20 |
+
NP_INT_TYPES = (np.int_, np.int8, np.int16, np.int32, np.int64, np.uint, np.uint8, np.uint16, np.uint32, np.uint64)
|
| 21 |
+
if int(np.version.full_version.split('.')[0]) == 1:
|
| 22 |
+
NP_BOOL_TYPES = (np.bool_, np.bool8)
|
| 23 |
+
NP_FLOAT_TYPES = (np.float_, np.float16, np.float32, np.float64)
|
| 24 |
+
else:
|
| 25 |
+
NP_BOOL_TYPES = (np.bool_, np.bool)
|
| 26 |
+
NP_FLOAT_TYPES = (np.float16, np.float32, np.float64)
|
| 27 |
+
|
| 28 |
+
def to_dict(obj):
|
| 29 |
+
return json.loads(json.dumps(obj, default=lambda o: o.__dict__, ensure_ascii=False))
|
| 30 |
+
|
| 31 |
+
def serialize_np(obj):
|
| 32 |
+
if isinstance(obj, np.ndarray):
|
| 33 |
+
return obj.tolist()
|
| 34 |
+
elif isinstance(obj, np.ScalarType):
|
| 35 |
+
if isinstance(obj, NP_BOOL_TYPES):
|
| 36 |
+
return bool(obj)
|
| 37 |
+
elif isinstance(obj, NP_FLOAT_TYPES):
|
| 38 |
+
return float(obj)
|
| 39 |
+
elif isinstance(obj, NP_INT_TYPES):
|
| 40 |
+
return int(obj)
|
| 41 |
+
return obj
|
| 42 |
+
|
| 43 |
+
def json_dump_nested_obj(obj, **kwargs):
|
| 44 |
+
def _default(obj):
|
| 45 |
+
if isinstance(obj, (np.ndarray, np.ScalarType)):
|
| 46 |
+
return serialize_np(obj)
|
| 47 |
+
return obj.__dict__
|
| 48 |
+
return json.dumps(obj, default=lambda o: _default(o), ensure_ascii=False, **kwargs)
|
| 49 |
+
|
| 50 |
+
# https://stackoverflow.com/questions/26646362/numpy-array-is-not-json-serializable
|
| 51 |
+
class NumpyEncoder(json.JSONEncoder):
|
| 52 |
+
def default(self, obj):
|
| 53 |
+
if isinstance(obj, (np.ndarray, np.ScalarType)):
|
| 54 |
+
return serialize_np(obj)
|
| 55 |
+
return json.JSONEncoder.default(self, obj)
|
| 56 |
+
|
| 57 |
+
def find_all_imgs(img_dir, abs_path=False, sort=False):
|
| 58 |
+
imglist = []
|
| 59 |
+
for filename in os.listdir(img_dir):
|
| 60 |
+
file_suffix = Path(filename).suffix
|
| 61 |
+
if file_suffix.lower() not in IMG_EXT:
|
| 62 |
+
continue
|
| 63 |
+
if abs_path:
|
| 64 |
+
imglist.append(osp.join(img_dir, filename))
|
| 65 |
+
else:
|
| 66 |
+
imglist.append(filename)
|
| 67 |
+
|
| 68 |
+
if sort:
|
| 69 |
+
imglist = natsorted(imglist)
|
| 70 |
+
|
| 71 |
+
return imglist
|
| 72 |
+
|
| 73 |
+
def find_all_files_recursive(tgt_dir: Union[List, str], ext: Union[List, set], exclude_dirs=None):
|
| 74 |
+
if isinstance(tgt_dir, str):
|
| 75 |
+
tgt_dir = [tgt_dir]
|
| 76 |
+
|
| 77 |
+
if exclude_dirs is None:
|
| 78 |
+
exclude_dirs = set()
|
| 79 |
+
|
| 80 |
+
filelst = []
|
| 81 |
+
for d in tgt_dir:
|
| 82 |
+
for root, _, files in os.walk(d):
|
| 83 |
+
if osp.basename(root) in exclude_dirs:
|
| 84 |
+
continue
|
| 85 |
+
for f in files:
|
| 86 |
+
if Path(f).suffix.lower() in ext:
|
| 87 |
+
filelst.append(osp.join(root, f))
|
| 88 |
+
|
| 89 |
+
return filelst
|
| 90 |
+
|
| 91 |
+
def imread(imgpath, read_type=cv2.IMREAD_COLOR, max_retry_limit=5, retry_interval=0.1):
|
| 92 |
+
if not osp.exists(imgpath):
|
| 93 |
+
return None
|
| 94 |
+
|
| 95 |
+
num_tries = 0
|
| 96 |
+
while True:
|
| 97 |
+
try:
|
| 98 |
+
img = Image.open(imgpath)
|
| 99 |
+
if read_type != cv2.IMREAD_GRAYSCALE:
|
| 100 |
+
img = img.convert('RGB')
|
| 101 |
+
img = np.array(img)
|
| 102 |
+
break
|
| 103 |
+
except PIL.UnidentifiedImageError as e:
|
| 104 |
+
# IMG I/O thread might not finished yet
|
| 105 |
+
num_tries += 1
|
| 106 |
+
if max_retry_limit is not None and num_tries >= max_retry_limit:
|
| 107 |
+
LOGGER.exception(e)
|
| 108 |
+
return None
|
| 109 |
+
LOGGER.warning(f'PIL.UnidentifiedImageError: failed to read {imgpath}, retries: {num_tries} / {max_retry_limit}')
|
| 110 |
+
time.sleep(retry_interval)
|
| 111 |
+
|
| 112 |
+
if read_type == cv2.IMREAD_GRAYSCALE:
|
| 113 |
+
if img.ndim == 3:
|
| 114 |
+
if img.shape[-1] == 3:
|
| 115 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 116 |
+
elif img.shape[-1] == 4:
|
| 117 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGBA2GRAY)
|
| 118 |
+
elif img.shape[-1] == 1:
|
| 119 |
+
img = img[..., 0]
|
| 120 |
+
else:
|
| 121 |
+
raise
|
| 122 |
+
return img
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def imwrite(img_path, img, ext='.png', quality=100, jxl_encode_effort=3):
|
| 126 |
+
# cv2 writing is faster than PIL
|
| 127 |
+
suffix = Path(img_path).suffix
|
| 128 |
+
ext = ext.lower()
|
| 129 |
+
assert ext in IMG_EXT
|
| 130 |
+
if suffix != '':
|
| 131 |
+
img_path = img_path.replace(suffix, ext)
|
| 132 |
+
else:
|
| 133 |
+
img_path += ext
|
| 134 |
+
encode_param = None
|
| 135 |
+
if ext in {'.jpg', '.jpeg'}:
|
| 136 |
+
encode_param = [cv2.IMWRITE_JPEG_QUALITY, quality]
|
| 137 |
+
elif ext == '.webp':
|
| 138 |
+
encode_param = [cv2.IMWRITE_WEBP_QUALITY, quality]
|
| 139 |
+
if ext == '.jxl':
|
| 140 |
+
# jxl_encode_effort: https://github.com/Isotr0py/pillow-jpegxl-plugin/issues/23
|
| 141 |
+
# higher values theoretically produce smaller files at the expense of time, 3 seems to strike a balance
|
| 142 |
+
lossless = quality > 99 # quality=100, lossless=False seems to result in larger file compared with lossless=True
|
| 143 |
+
Image.fromarray(img).save(img_path, quality=quality, lossless=lossless, effort=jxl_encode_effort)
|
| 144 |
+
else:
|
| 145 |
+
if len(img.shape) == 3:
|
| 146 |
+
if img.shape[-1] == 3:
|
| 147 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
| 148 |
+
elif img.shape[-1] == 4:
|
| 149 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGBA2BGRA)
|
| 150 |
+
cv2.imencode(ext, img, encode_param)[1].tofile(img_path)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def show_img_by_dict(imgdicts):
|
| 154 |
+
for keyname in imgdicts.keys():
|
| 155 |
+
cv2.imshow(keyname, imgdicts[keyname])
|
| 156 |
+
cv2.waitKey(0)
|
| 157 |
+
|
| 158 |
+
def text_is_empty(text) -> bool:
|
| 159 |
+
if isinstance(text, str):
|
| 160 |
+
if text.strip() == '':
|
| 161 |
+
return True
|
| 162 |
+
if isinstance(text, list):
|
| 163 |
+
for t in text:
|
| 164 |
+
t_is_empty = text_is_empty(t)
|
| 165 |
+
if not t_is_empty:
|
| 166 |
+
return False
|
| 167 |
+
return True
|
| 168 |
+
elif text is None:
|
| 169 |
+
return True
|
| 170 |
+
|
| 171 |
+
def empty_func(*args, **kwargs):
|
| 172 |
+
return
|
| 173 |
+
|
| 174 |
+
def get_obj_from_str(string, reload=False):
|
| 175 |
+
module, cls = string.rsplit(".", 1)
|
| 176 |
+
if reload:
|
| 177 |
+
module_imp = importlib.import_module(module)
|
| 178 |
+
importlib.reload(module_imp)
|
| 179 |
+
return getattr(importlib.import_module(module, package=None), cls)
|
| 180 |
+
|
| 181 |
+
def get_module_from_str(module_str: str):
|
| 182 |
+
return importlib.import_module(module_str, package=None)
|
| 183 |
+
|
| 184 |
+
def build_funcmap(module_str: str, params_names: List[str], func_prefix: str = '', func_suffix: str = '', fallback_func: Callable = None, verbose: bool = True) -> Dict:
|
| 185 |
+
|
| 186 |
+
if fallback_func is None:
|
| 187 |
+
fallback_func = empty_func
|
| 188 |
+
|
| 189 |
+
module = get_module_from_str(module_str)
|
| 190 |
+
|
| 191 |
+
funcmap = {}
|
| 192 |
+
for param in params_names:
|
| 193 |
+
tgt_func = f'{func_prefix}{param}{func_suffix}'
|
| 194 |
+
try:
|
| 195 |
+
tgt_func = getattr(module, tgt_func)
|
| 196 |
+
except Exception as e:
|
| 197 |
+
if verbose:
|
| 198 |
+
print(f'failed to import {tgt_func} from {module_str}: {e}')
|
| 199 |
+
tgt_func = fallback_func
|
| 200 |
+
funcmap[param] = tgt_func
|
| 201 |
+
|
| 202 |
+
return funcmap
|
| 203 |
+
|
| 204 |
+
def _b64encode(x: bytes) -> str:
|
| 205 |
+
return base64.b64encode(x).decode("utf-8")
|
| 206 |
+
|
| 207 |
+
def img2b64(img):
|
| 208 |
+
"""
|
| 209 |
+
Convert a PIL image to a base64-encoded string.
|
| 210 |
+
"""
|
| 211 |
+
if isinstance(img, np.ndarray):
|
| 212 |
+
img = Image.fromarray(img)
|
| 213 |
+
buffered = io.BytesIO()
|
| 214 |
+
img.save(buffered, format='PNG')
|
| 215 |
+
return _b64encode(buffered.getvalue())
|
| 216 |
+
|
| 217 |
+
def save_encoded_image(b64_image: str, output_path: str):
|
| 218 |
+
with open(output_path, "wb") as image_file:
|
| 219 |
+
image_file.write(base64.b64decode(b64_image))
|
| 220 |
+
|
| 221 |
+
def submit_request(url, data, exist_on_exception=True, auth=None, wait_time = 5):
|
| 222 |
+
response = None
|
| 223 |
+
try:
|
| 224 |
+
while True:
|
| 225 |
+
try:
|
| 226 |
+
response = requests.post(url, data=data, auth=auth)
|
| 227 |
+
response.raise_for_status()
|
| 228 |
+
break
|
| 229 |
+
except Exception as e:
|
| 230 |
+
if wait_time > 0:
|
| 231 |
+
print(traceback.format_exc(), file=sys.stderr)
|
| 232 |
+
print(f'sleep {wait_time} sec...')
|
| 233 |
+
time.sleep(wait_time)
|
| 234 |
+
continue
|
| 235 |
+
else:
|
| 236 |
+
raise e
|
| 237 |
+
except Exception as e:
|
| 238 |
+
print(traceback.format_exc(), file=sys.stderr)
|
| 239 |
+
if response is not None:
|
| 240 |
+
print('response content: ' + response.text)
|
| 241 |
+
if exist_on_exception:
|
| 242 |
+
exit()
|
| 243 |
+
return response
|
utils/logger.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
import os.path as osp
|
| 5 |
+
from glob import glob
|
| 6 |
+
import termcolor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
if os.name == "nt": # Windows
|
| 10 |
+
import colorama
|
| 11 |
+
colorama.init()
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
COLORS = {
|
| 15 |
+
"WARNING": "yellow",
|
| 16 |
+
"INFO": "white",
|
| 17 |
+
"DEBUG": "blue",
|
| 18 |
+
"CRITICAL": "red",
|
| 19 |
+
"ERROR": "red",
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ColoredFormatter(logging.Formatter):
|
| 24 |
+
def __init__(self, fmt, use_color=True):
|
| 25 |
+
logging.Formatter.__init__(self, fmt)
|
| 26 |
+
self.use_color = use_color
|
| 27 |
+
|
| 28 |
+
def format(self, record):
|
| 29 |
+
levelname = record.levelname
|
| 30 |
+
if self.use_color and levelname in COLORS:
|
| 31 |
+
|
| 32 |
+
def colored(text):
|
| 33 |
+
return termcolor.colored(
|
| 34 |
+
text,
|
| 35 |
+
color=COLORS[levelname],
|
| 36 |
+
attrs={"bold": True},
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
record.levelname2 = colored("{:<7}".format(record.levelname))
|
| 40 |
+
record.message2 = colored(record.getMessage())
|
| 41 |
+
|
| 42 |
+
asctime2 = datetime.datetime.fromtimestamp(record.created)
|
| 43 |
+
record.asctime2 = termcolor.colored(asctime2, color="green")
|
| 44 |
+
|
| 45 |
+
record.module2 = termcolor.colored(record.module, color="cyan")
|
| 46 |
+
record.funcName2 = termcolor.colored(record.funcName, color="cyan")
|
| 47 |
+
record.lineno2 = termcolor.colored(record.lineno, color="cyan")
|
| 48 |
+
return logging.Formatter.format(self, record)
|
| 49 |
+
|
| 50 |
+
FORMAT = (
|
| 51 |
+
"[%(levelname2)s] %(module2)s:%(funcName2)s:%(lineno2)s - %(message2)s"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
class ColoredLogger(logging.Logger):
|
| 55 |
+
|
| 56 |
+
def __init__(self, name):
|
| 57 |
+
logging.Logger.__init__(self, name, logging.INFO)
|
| 58 |
+
|
| 59 |
+
color_formatter = ColoredFormatter(FORMAT)
|
| 60 |
+
|
| 61 |
+
console = logging.StreamHandler()
|
| 62 |
+
console.setFormatter(color_formatter)
|
| 63 |
+
|
| 64 |
+
self.addHandler(console)
|
| 65 |
+
return
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def setup_logging(logfile_dir: str, max_num_logs=14):
|
| 69 |
+
|
| 70 |
+
if not osp.exists(logfile_dir):
|
| 71 |
+
os.makedirs(logfile_dir)
|
| 72 |
+
else:
|
| 73 |
+
old_logs = glob(osp.join(logfile_dir, '*.log'))
|
| 74 |
+
old_logs.sort()
|
| 75 |
+
n_log = len(old_logs)
|
| 76 |
+
if n_log >= max_num_logs:
|
| 77 |
+
to_remove = n_log - max_num_logs + 1
|
| 78 |
+
try:
|
| 79 |
+
for ii in range(to_remove):
|
| 80 |
+
os.remove(old_logs[ii])
|
| 81 |
+
except Exception as e:
|
| 82 |
+
logger.error(e)
|
| 83 |
+
|
| 84 |
+
logfilename = datetime.datetime.now().strftime('_%Y_%m_%d-%H_%M_%S.log')
|
| 85 |
+
logfilep = osp.join(logfile_dir, logfilename)
|
| 86 |
+
fh = logging.FileHandler(logfilep, mode='w', encoding='utf-8')
|
| 87 |
+
fh.setFormatter(
|
| 88 |
+
logging.Formatter(
|
| 89 |
+
("[%(levelname)s] %(module)s:%(funcName)s:%(lineno)s - %(message)s")
|
| 90 |
+
)
|
| 91 |
+
)
|
| 92 |
+
fh.setLevel(logging.DEBUG)
|
| 93 |
+
logger.addHandler(fh)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
logging.setLoggerClass(ColoredLogger)
|
| 97 |
+
logger = logging.getLogger('BallonTranslator')
|
| 98 |
+
logger.setLevel(logging.DEBUG)
|
| 99 |
+
logger.propagate = False
|
utils/message.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import traceback
|
| 2 |
+
from typing import Callable, List, Dict
|
| 3 |
+
|
| 4 |
+
from . import shared
|
| 5 |
+
from .logger import logger as LOGGER
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_error_dialog(exception: Exception, error_msg: str = None, exception_type: str = None):
|
| 9 |
+
'''
|
| 10 |
+
Popup a error dialog in main thread
|
| 11 |
+
Args:
|
| 12 |
+
error_msg: Description text prepend before str(exception)
|
| 13 |
+
exception_type: Specify it to avoid errors dialog of the same type popup repeatedly
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
detail_traceback = traceback.format_exc()
|
| 17 |
+
|
| 18 |
+
if exception_type is None:
|
| 19 |
+
exception_type = ''
|
| 20 |
+
|
| 21 |
+
exception_type_empty = exception_type == ''
|
| 22 |
+
show_exception = exception_type_empty or exception_type not in shared.showed_exception
|
| 23 |
+
|
| 24 |
+
if show_exception:
|
| 25 |
+
if error_msg is None:
|
| 26 |
+
error_msg = str(exception)
|
| 27 |
+
else:
|
| 28 |
+
error_msg = str(exception) + '\n' + error_msg
|
| 29 |
+
LOGGER.error(error_msg + '\n')
|
| 30 |
+
LOGGER.error(detail_traceback)
|
| 31 |
+
|
| 32 |
+
if not shared.HEADLESS:
|
| 33 |
+
shared.create_errdialog_in_mainthread(error_msg, detail_traceback, exception_type)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def create_info_dialog(info_msg, btn_type=None, modal: bool = False, frame_less: bool = False, signal_slot_map_list: List[Dict] = None):
|
| 37 |
+
'''
|
| 38 |
+
Popup a info dialog in main thread
|
| 39 |
+
'''
|
| 40 |
+
LOGGER.info(info_msg)
|
| 41 |
+
if not shared.HEADLESS:
|
| 42 |
+
shared.create_infodialog_in_mainthread({'info_msg': info_msg, 'btn_type': btn_type, 'modal': modal, 'frame_less': frame_less, 'signal_slot_map_list': signal_slot_map_list})
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def connect_once(signal, exec_func: Callable):
|
| 46 |
+
'''
|
| 47 |
+
signal.emit will only trigger exec_func once
|
| 48 |
+
'''
|
| 49 |
+
|
| 50 |
+
def _disconnect_after_called(*func_args, **func_kwargs):
|
| 51 |
+
|
| 52 |
+
def _try_disconnect():
|
| 53 |
+
try:
|
| 54 |
+
signal.disconnect(connect_func)
|
| 55 |
+
except:
|
| 56 |
+
print('Failed to disconnect')
|
| 57 |
+
print(traceback.format_exc())
|
| 58 |
+
|
| 59 |
+
try:
|
| 60 |
+
exec_func(*func_args, **func_kwargs)
|
| 61 |
+
except Exception as e:
|
| 62 |
+
_try_disconnect()
|
| 63 |
+
raise e
|
| 64 |
+
_try_disconnect()
|
| 65 |
+
|
| 66 |
+
connect_func = _disconnect_after_called
|
| 67 |
+
signal.connect(_disconnect_after_called)
|
utils/package.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# copied from https://github.com/HansBug/hbutils/blob/main/hbutils/system/python/package.py
|
| 2 |
+
# to replace the deprecated pkg_resources
|
| 3 |
+
|
| 4 |
+
import functools
|
| 5 |
+
import itertools
|
| 6 |
+
import os
|
| 7 |
+
import pathlib
|
| 8 |
+
import subprocess
|
| 9 |
+
import sys
|
| 10 |
+
from typing import List, Optional
|
| 11 |
+
|
| 12 |
+
from packaging.requirements import Requirement
|
| 13 |
+
from packaging.utils import canonicalize_name
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
import importlib.metadata as importlib_metadata
|
| 17 |
+
except (ModuleNotFoundError, ImportError):
|
| 18 |
+
import importlib_metadata
|
| 19 |
+
from packaging.version import Version
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def package_version(name: str) -> Optional[Version]:
|
| 23 |
+
"""
|
| 24 |
+
Overview:
|
| 25 |
+
Get version of package with given ``name``.
|
| 26 |
+
|
| 27 |
+
:param name: Name of the package, case is not sensitive.
|
| 28 |
+
:return: A :class:`packing.version.Version` object. If the package is not installed, return ``None``.
|
| 29 |
+
|
| 30 |
+
Examples::
|
| 31 |
+
>>> from hbutils.system import package_version
|
| 32 |
+
>>>
|
| 33 |
+
>>> package_version('pip')
|
| 34 |
+
<Version('21.3.1')>
|
| 35 |
+
>>> package_version('setuptools')
|
| 36 |
+
<Version('59.6.0')>
|
| 37 |
+
>>> package_version('not_a_package')
|
| 38 |
+
None
|
| 39 |
+
"""
|
| 40 |
+
try:
|
| 41 |
+
return Version(importlib_metadata.distribution(canonicalize_name(name)).version)
|
| 42 |
+
except importlib_metadata.PackageNotFoundError:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _nonblank(text):
|
| 47 |
+
return text and not text.startswith('#')
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@functools.singledispatch
|
| 51 |
+
def yield_lines(iterable):
|
| 52 |
+
r"""
|
| 53 |
+
Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L537 .
|
| 54 |
+
Yield valid lines of a string or iterable.
|
| 55 |
+
>>> list(yield_lines(''))
|
| 56 |
+
[]
|
| 57 |
+
>>> list(yield_lines(['foo', 'bar']))
|
| 58 |
+
['foo', 'bar']
|
| 59 |
+
>>> list(yield_lines('foo\nbar'))
|
| 60 |
+
['foo', 'bar']
|
| 61 |
+
>>> list(yield_lines('\nfoo\n#bar\nbaz #comment'))
|
| 62 |
+
['foo', 'baz #comment']
|
| 63 |
+
>>> list(yield_lines(['foo\nbar', 'baz', 'bing\n\n\n']))
|
| 64 |
+
['foo', 'bar', 'baz', 'bing']
|
| 65 |
+
"""
|
| 66 |
+
return itertools.chain.from_iterable(map(yield_lines, iterable))
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@yield_lines.register(str)
|
| 70 |
+
def _(text):
|
| 71 |
+
return filter(_nonblank, map(str.strip, text.splitlines()))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def drop_comment(line):
|
| 75 |
+
"""
|
| 76 |
+
Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L560 .
|
| 77 |
+
Drop comments.
|
| 78 |
+
>>> drop_comment('foo # bar')
|
| 79 |
+
'foo'
|
| 80 |
+
A hash without a space may be in a URL.
|
| 81 |
+
>>> drop_comment('https://example.com/foo#bar')
|
| 82 |
+
'https://example.com/foo#bar'
|
| 83 |
+
"""
|
| 84 |
+
return line.partition(' #')[0]
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def join_continuation(lines):
|
| 88 |
+
r"""
|
| 89 |
+
Based on https://github.com/jaraco/jaraco.text/blob/main/jaraco/text/__init__.py#L575 .
|
| 90 |
+
Join lines continued by a trailing backslash.
|
| 91 |
+
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
|
| 92 |
+
['foobar', 'baz']
|
| 93 |
+
>>> list(join_continuation(['foo \\', 'bar', 'baz']))
|
| 94 |
+
['foobar', 'baz']
|
| 95 |
+
>>> list(join_continuation(['foo \\', 'bar \\', 'baz']))
|
| 96 |
+
['foobarbaz']
|
| 97 |
+
Not sure why, but...
|
| 98 |
+
The character preceding the backslash is also elided.
|
| 99 |
+
>>> list(join_continuation(['goo\\', 'dly']))
|
| 100 |
+
['godly']
|
| 101 |
+
A terrible idea, but...
|
| 102 |
+
If no line is available to continue, suppress the lines.
|
| 103 |
+
>>> list(join_continuation(['foo', 'bar\\', 'baz\\']))
|
| 104 |
+
['foo']
|
| 105 |
+
"""
|
| 106 |
+
lines = iter(lines)
|
| 107 |
+
for item in lines:
|
| 108 |
+
while item.endswith('\\'):
|
| 109 |
+
try: # pragma: no cover
|
| 110 |
+
item = item[:-2].strip() + next(lines)
|
| 111 |
+
except StopIteration:
|
| 112 |
+
return
|
| 113 |
+
yield item
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def load_req_file(requirements_file: str) -> List[str]:
|
| 117 |
+
"""
|
| 118 |
+
Overview:
|
| 119 |
+
Load requirements items from a ``requirements.txt`` file.
|
| 120 |
+
|
| 121 |
+
:param requirements_file: Requirements file.
|
| 122 |
+
:return requirements: List of requirements.
|
| 123 |
+
|
| 124 |
+
Examples::
|
| 125 |
+
>>> from hbutils.system import load_req_file
|
| 126 |
+
>>> load_req_file('requirements.txt')
|
| 127 |
+
['packaging>=21.3', 'setuptools>=50.0']
|
| 128 |
+
"""
|
| 129 |
+
with pathlib.Path(requirements_file).open() as reqfile:
|
| 130 |
+
return list(map(
|
| 131 |
+
lambda x: str(Requirement(x)),
|
| 132 |
+
join_continuation(map(drop_comment, yield_lines(reqfile)))
|
| 133 |
+
))
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def pip(*args, silent: bool = False):
|
| 137 |
+
"""
|
| 138 |
+
Overview:
|
| 139 |
+
Run pip command with code.
|
| 140 |
+
|
| 141 |
+
:param args: Command line arguments for ``pip`` command.
|
| 142 |
+
:param silent: Do not print anything. Default is false, which means print the output to ``sys.stdout`` \
|
| 143 |
+
and ``sys.stderr``.
|
| 144 |
+
|
| 145 |
+
Examples::
|
| 146 |
+
>>> from hbutils.system import pip
|
| 147 |
+
>>> pip('-V')
|
| 148 |
+
pip 22.3.1 from /home/user/myproject/venv/lib/python3.7/site-packages/pip (python 3.7)
|
| 149 |
+
>>> pip('-V', silent=True) # nothing will be printed
|
| 150 |
+
"""
|
| 151 |
+
process = subprocess.run(
|
| 152 |
+
[sys.executable, '-m', 'pip', *args],
|
| 153 |
+
stdin=sys.stdin if not silent else None,
|
| 154 |
+
stdout=sys.stdout if not silent else subprocess.PIPE,
|
| 155 |
+
stderr=sys.stderr if not silent else subprocess.PIPE,
|
| 156 |
+
)
|
| 157 |
+
assert not process.returncode, f'Error when calling {process.args!r}{os.linesep}' \
|
| 158 |
+
f'Error Code - {process.returncode}{os.linesep}' \
|
| 159 |
+
f'Stdout:{os.linesep}' \
|
| 160 |
+
f'{process.stdout.decode()}{os.linesep}' \
|
| 161 |
+
f'{os.linesep}' \
|
| 162 |
+
f'Stderr:{os.linesep}' \
|
| 163 |
+
f'{process.stderr.decode()}{os.linesep}'
|
| 164 |
+
process.check_returncode()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _yield_reqs_to_install(req: Requirement, current_extra: str = ''):
|
| 168 |
+
if req.marker and not req.marker.evaluate({'extra': current_extra}):
|
| 169 |
+
return
|
| 170 |
+
|
| 171 |
+
try:
|
| 172 |
+
version = importlib_metadata.distribution(req.name).version
|
| 173 |
+
except importlib_metadata.PackageNotFoundError: # req not installed
|
| 174 |
+
yield req
|
| 175 |
+
else:
|
| 176 |
+
if req.specifier.contains(version, prereleases=True):
|
| 177 |
+
for child_req in (importlib_metadata.metadata(req.name).get_all('Requires-Dist') or []):
|
| 178 |
+
child_req_obj = Requirement(child_req)
|
| 179 |
+
|
| 180 |
+
need_check, ext = False, None
|
| 181 |
+
for extra in req.extras:
|
| 182 |
+
if child_req_obj.marker and child_req_obj.marker.evaluate({'extra': extra}):
|
| 183 |
+
need_check = True
|
| 184 |
+
ext = extra
|
| 185 |
+
break
|
| 186 |
+
|
| 187 |
+
if need_check: # check for extra reqs
|
| 188 |
+
yield from _yield_reqs_to_install(child_req_obj, ext)
|
| 189 |
+
|
| 190 |
+
else: # main version not match
|
| 191 |
+
yield req
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _check_req(req: Requirement):
|
| 195 |
+
return not bool(list(itertools.islice(_yield_reqs_to_install(req), 1)))
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def check_reqs(reqs: List[str]) -> bool:
|
| 199 |
+
"""
|
| 200 |
+
Overview:
|
| 201 |
+
Check if the given requirements are all satisfied.
|
| 202 |
+
|
| 203 |
+
:param reqs: List of requirements.
|
| 204 |
+
:return satisfied: All the requirements in ``reqs`` satisfied or not.
|
| 205 |
+
|
| 206 |
+
Examples::
|
| 207 |
+
>>> from hbutils.system import check_reqs
|
| 208 |
+
>>> check_reqs(['pip>=20.0'])
|
| 209 |
+
True
|
| 210 |
+
>>> check_reqs(['pip~=19.2'])
|
| 211 |
+
False
|
| 212 |
+
>>> check_reqs(['pip>=20.0', 'setuptools>=50.0'])
|
| 213 |
+
True
|
| 214 |
+
|
| 215 |
+
.. note::
|
| 216 |
+
If a requirement's marker is not satisfied in this environment,
|
| 217 |
+
**it will be ignored** instead of return ``False``.
|
| 218 |
+
"""
|
| 219 |
+
return all(map(lambda x: _check_req(Requirement(x)), reqs))
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def check_req_file(requirements_file: str) -> bool:
|
| 223 |
+
"""
|
| 224 |
+
Overview:
|
| 225 |
+
Check if the requirements in the given ``requirements_file`` is satisfied.
|
| 226 |
+
|
| 227 |
+
:param requirements_file: Requirements file, such as ``requirements.txt``.
|
| 228 |
+
:return satisfied: All the requirements in ``requirements_file`` satisfied or not.
|
| 229 |
+
|
| 230 |
+
Examples::
|
| 231 |
+
>>> from hbutils.system import check_req_file
|
| 232 |
+
>>>
|
| 233 |
+
>>> check_req_file('requirements.txt')
|
| 234 |
+
True
|
| 235 |
+
>>> check_req_file('requirements-test.txt')
|
| 236 |
+
True
|
| 237 |
+
"""
|
| 238 |
+
return check_reqs(load_req_file(requirements_file))
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def pip_install(reqs: List[str], silent: bool = False, force: bool = False, user: bool = False):
|
| 242 |
+
"""
|
| 243 |
+
Overview:
|
| 244 |
+
Pip install requirements with code.
|
| 245 |
+
Similar to ``pip install req1 req2 ...``.
|
| 246 |
+
|
| 247 |
+
:param reqs: Requirement items to install.
|
| 248 |
+
:param silent: Do not print anything. Default is ``False``.
|
| 249 |
+
:param force: Force execute the ``pip install`` command. Default is ``False`` which means the requirements \
|
| 250 |
+
will be checked before installation, and the installation will be only executed when \
|
| 251 |
+
some requirements not installed.
|
| 252 |
+
:param user: User mode, represents ``--user`` option in ``pip``.
|
| 253 |
+
|
| 254 |
+
Examples::
|
| 255 |
+
>>> from hbutils.system import pip_install
|
| 256 |
+
>>> pip_install(['scikit-learn']) # not installed
|
| 257 |
+
Looking in indexes: https://xxx/simple
|
| 258 |
+
Collecting scikit-learn
|
| 259 |
+
Using cached https://xxx/scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (24.8 MB)
|
| 260 |
+
Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn
|
| 261 |
+
Successfully installed joblib-1.2.0 scikit-learn-1.0.2 scipy-1.7.3 threadpoolctl-3.1.0
|
| 262 |
+
>>> pip_install(['numpy>=1.10.0']) # installed
|
| 263 |
+
>>> pip_install(['numpy>=1.10.0'], force=True) # force execute
|
| 264 |
+
Looking in indexes: https://xxx/simple
|
| 265 |
+
Requirement already satisfied: numpy>=1.10.0 in ./venv/lib/python3.7/site-packages (1.21.6)
|
| 266 |
+
"""
|
| 267 |
+
if force or not check_reqs(reqs):
|
| 268 |
+
pip('install', *(('--user',) if user else ()), *reqs, silent=silent)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def pip_install_req_file(requirements_file: str, silent: bool = False, force: bool = False, user: bool = False):
|
| 272 |
+
"""
|
| 273 |
+
Overview:
|
| 274 |
+
Pip install requirements from file with code.
|
| 275 |
+
Similar to ``pip install -r requirements.txt``.
|
| 276 |
+
|
| 277 |
+
:param requirements_file: Requirements file, such as ``requirements.txt``.
|
| 278 |
+
:param silent: Do not print anything. Default is ``False``.
|
| 279 |
+
:param force: Force execute the ``pip install`` command. Default is ``False`` which means the requirements \
|
| 280 |
+
will be checked before installation, and the installation will be only executed when \
|
| 281 |
+
some requirements not installed.
|
| 282 |
+
:param user: User mode, represents ``--user`` option in ``pip``.
|
| 283 |
+
|
| 284 |
+
Examples::
|
| 285 |
+
>>> from hbutils.system import pip_install_req_file
|
| 286 |
+
>>> pip_install_req_file('requirements.txt') # pip install -r requirements.txt
|
| 287 |
+
"""
|
| 288 |
+
if force or not check_req_file(requirements_file):
|
| 289 |
+
pip('install', *(('--user',) if user else ()), '-r', requirements_file, silent=silent)
|
utils/proj_imgtrans.py
ADDED
|
@@ -0,0 +1,720 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os, json, shutil, re, docx, docx2txt, piexif, cv2
|
| 2 |
+
from docx.shared import Inches
|
| 3 |
+
from docx import Document
|
| 4 |
+
import piexif.helper
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os.path as osp
|
| 7 |
+
from typing import Tuple, Union, List, Dict
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from utils.watermark_utils import apply_watermark_to_pil_image
|
| 11 |
+
from .logger import logger as LOGGER
|
| 12 |
+
from .io_utils import find_all_imgs, imread, imwrite, NumpyEncoder
|
| 13 |
+
from .textblock import TextBlock, FontFormat
|
| 14 |
+
from .config import pcfg
|
| 15 |
+
from . import shared
|
| 16 |
+
from .exceptions import ImgnameNotInProjectException, ProjectLoadFailureException, ProjectDirNotExistException, ProjectNotSupportedException
|
| 17 |
+
|
| 18 |
+
class ImageLoadException(Exception):
|
| 19 |
+
def __init__(self, img_path):
|
| 20 |
+
super().__init__(f"Failed to load image: {img_path}")
|
| 21 |
+
self.img_path = img_path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_last_modified_file(file_prefix, exts, ext_fallback=None):
|
| 25 |
+
'''
|
| 26 |
+
get last modified file from files sharing same prefix
|
| 27 |
+
'''
|
| 28 |
+
latest_time = -1
|
| 29 |
+
latest_f = None
|
| 30 |
+
for ext in exts:
|
| 31 |
+
tmp_p = file_prefix + ext
|
| 32 |
+
if osp.exists(tmp_p) and osp.getmtime(tmp_p) > latest_time:
|
| 33 |
+
latest_time = osp.getmtime(tmp_p)
|
| 34 |
+
latest_f = tmp_p
|
| 35 |
+
if latest_f is None:
|
| 36 |
+
if ext_fallback is not None:
|
| 37 |
+
latest_f = file_prefix + ext_fallback
|
| 38 |
+
else:
|
| 39 |
+
latest_f = file_prefix + exts[0]
|
| 40 |
+
return latest_f
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def write_jpg_metadata(imgpath: str, metadata="a metadata"):
|
| 44 |
+
exif_dict = {"Exif":{piexif.ExifIFD.UserComment: piexif.helper.UserComment.dump(metadata, encoding='unicode')}}
|
| 45 |
+
exif_bytes = piexif.dump(exif_dict)
|
| 46 |
+
piexif.insert(exif_bytes, imgpath)
|
| 47 |
+
|
| 48 |
+
def read_jpg_metadata(imgpath: str):
|
| 49 |
+
exif_dict = piexif.load(imgpath)
|
| 50 |
+
user_comment = piexif.helper.UserComment.load(exif_dict["Exif"][piexif.ExifIFD.UserComment])
|
| 51 |
+
bubdict = json.loads(user_comment)
|
| 52 |
+
return bubdict
|
| 53 |
+
|
| 54 |
+
page_start_pattern = re.compile(r'^###\s+', re.MULTILINE)
|
| 55 |
+
text_blkid_start_pattern = re.compile(r'^\d+\.', re.MULTILINE)
|
| 56 |
+
|
| 57 |
+
def parse_txt_translation(file_path: str):
|
| 58 |
+
with open(file_path, 'r', encoding='utf8') as f:
|
| 59 |
+
content = f.read()
|
| 60 |
+
page_start = None
|
| 61 |
+
page_list = []
|
| 62 |
+
for matched in page_start_pattern.finditer(content):
|
| 63 |
+
start, end = matched.span()
|
| 64 |
+
if page_start is not None:
|
| 65 |
+
page_list.append({'page_content': content[page_start: start]})
|
| 66 |
+
page_start = start
|
| 67 |
+
if page_start is not None:
|
| 68 |
+
page_list.append({'page_content': content[page_start:]})
|
| 69 |
+
|
| 70 |
+
for page_dict in page_list:
|
| 71 |
+
page_content = page_dict['page_content']
|
| 72 |
+
page_dict['page_name'] = page_start_pattern.sub('', page_content.split('\n')[0]).strip()
|
| 73 |
+
blkid_start = blkid_end = None
|
| 74 |
+
blk_list = []
|
| 75 |
+
for matched in text_blkid_start_pattern.finditer(page_content):
|
| 76 |
+
start, end = matched.span()
|
| 77 |
+
if blkid_start is not None:
|
| 78 |
+
blk_list.append(page_content[blkid_end: start].strip())
|
| 79 |
+
blkid_start = start
|
| 80 |
+
blkid_end = end
|
| 81 |
+
if blkid_start is not None:
|
| 82 |
+
blk_list.append(page_content[blkid_end:].strip())
|
| 83 |
+
page_dict['blk_list'] = blk_list
|
| 84 |
+
|
| 85 |
+
return page_list
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class TextBlkEncoder(NumpyEncoder):
|
| 89 |
+
def default(self, obj):
|
| 90 |
+
if isinstance(obj, TextBlock):
|
| 91 |
+
return obj.to_dict()
|
| 92 |
+
elif isinstance(obj, FontFormat):
|
| 93 |
+
return vars(obj)
|
| 94 |
+
return NumpyEncoder.default(self, obj)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class ProjImgTrans:
|
| 98 |
+
|
| 99 |
+
def __init__(self, directory: str = None):
|
| 100 |
+
self.type = 'imgtrans'
|
| 101 |
+
self.directory: str = None
|
| 102 |
+
self.pages: Dict[str, List[TextBlock]] = {}
|
| 103 |
+
self._pagename2idx = {}
|
| 104 |
+
self._idx2pagename = {}
|
| 105 |
+
|
| 106 |
+
self._fuzzy_inpainted_list = None
|
| 107 |
+
|
| 108 |
+
self.not_found_pages: Dict[str, List[TextBlock]] = {}
|
| 109 |
+
self.new_pages: List[str] = []
|
| 110 |
+
self.proj_path: str = None
|
| 111 |
+
|
| 112 |
+
self.current_img: str = None
|
| 113 |
+
self.img_array: np.ndarray = None
|
| 114 |
+
self.mask_array: np.ndarray = None
|
| 115 |
+
self.inpainted_array: np.ndarray = None
|
| 116 |
+
|
| 117 |
+
# Watermark settings
|
| 118 |
+
self.enable_watermark = False
|
| 119 |
+
self.watermark_path = ""
|
| 120 |
+
self.watermark_opacity = 0.7
|
| 121 |
+
|
| 122 |
+
if directory is not None:
|
| 123 |
+
self.load(directory)
|
| 124 |
+
|
| 125 |
+
def idx2pagename(self, idx: int) -> str:
|
| 126 |
+
return self._idx2pagename[idx]
|
| 127 |
+
|
| 128 |
+
def pagename2idx(self, pagename: str) -> int:
|
| 129 |
+
if pagename in self.pages:
|
| 130 |
+
return self._pagename2idx[pagename]
|
| 131 |
+
return -1
|
| 132 |
+
|
| 133 |
+
def proj_name(self) -> str:
|
| 134 |
+
return self.type+'_'+osp.basename(self.directory)
|
| 135 |
+
|
| 136 |
+
def load(self, directory: str, json_path: str = None) -> bool:
|
| 137 |
+
self.directory = directory
|
| 138 |
+
if json_path is None:
|
| 139 |
+
self.proj_path = osp.join(self.directory, self.proj_name() + '.json')
|
| 140 |
+
else:
|
| 141 |
+
self.proj_path = json_path
|
| 142 |
+
new_proj = False
|
| 143 |
+
if not osp.exists(self.proj_path):
|
| 144 |
+
new_proj = True
|
| 145 |
+
self.new_project()
|
| 146 |
+
else:
|
| 147 |
+
try:
|
| 148 |
+
with open(self.proj_path, 'r', encoding='utf8') as f:
|
| 149 |
+
proj_dict = json.loads(f.read())
|
| 150 |
+
except Exception as e:
|
| 151 |
+
raise ProjectLoadFailureException(e)
|
| 152 |
+
self.load_from_dict(proj_dict)
|
| 153 |
+
if not osp.exists(self.inpainted_dir()):
|
| 154 |
+
os.makedirs(self.inpainted_dir())
|
| 155 |
+
if not osp.exists(self.mask_dir()):
|
| 156 |
+
os.makedirs(self.mask_dir())
|
| 157 |
+
|
| 158 |
+
# Fix: use self instead of proj and check proj_dict existence
|
| 159 |
+
if 'enable_watermark' in locals() and 'proj_dict' in locals() and 'enable_watermark' in proj_dict:
|
| 160 |
+
self.enable_watermark = proj_dict['enable_watermark']
|
| 161 |
+
if 'proj_dict' in locals() and 'watermark_path' in proj_dict:
|
| 162 |
+
self.watermark_path = proj_dict['watermark_path']
|
| 163 |
+
if 'proj_dict' in locals() and 'watermark_opacity' in proj_dict:
|
| 164 |
+
self.watermark_opacity = proj_dict['watermark_opacity']
|
| 165 |
+
|
| 166 |
+
return new_proj
|
| 167 |
+
|
| 168 |
+
def mask_dir(self):
|
| 169 |
+
return osp.join(self.directory, 'mask')
|
| 170 |
+
|
| 171 |
+
def inpainted_dir(self):
|
| 172 |
+
return osp.join(self.directory, 'inpainted')
|
| 173 |
+
|
| 174 |
+
def result_dir(self):
|
| 175 |
+
return osp.join(self.directory, 'result')
|
| 176 |
+
|
| 177 |
+
def load_from_dict(self, proj_dict: dict):
|
| 178 |
+
self.set_current_img(None)
|
| 179 |
+
try:
|
| 180 |
+
self.pages = {}
|
| 181 |
+
self._pagename2idx = {}
|
| 182 |
+
self._idx2pagename = {}
|
| 183 |
+
self.not_found_pages = {}
|
| 184 |
+
page_dict = proj_dict['pages']
|
| 185 |
+
not_found_pages = list(page_dict.keys())
|
| 186 |
+
found_pages = find_all_imgs(img_dir=self.directory, abs_path=False, sort=True)
|
| 187 |
+
for ii, imname in enumerate(found_pages):
|
| 188 |
+
if imname in page_dict:
|
| 189 |
+
self.pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
|
| 190 |
+
not_found_pages.remove(imname)
|
| 191 |
+
else:
|
| 192 |
+
self.pages[imname] = []
|
| 193 |
+
self.new_pages.append(imname)
|
| 194 |
+
self._pagename2idx[imname] = ii
|
| 195 |
+
self._idx2pagename[ii] = imname
|
| 196 |
+
for imname in not_found_pages:
|
| 197 |
+
self.not_found_pages[imname] = [TextBlock(**blk_dict) for blk_dict in page_dict[imname]]
|
| 198 |
+
except Exception as e:
|
| 199 |
+
raise ProjectNotSupportedException(e)
|
| 200 |
+
set_img_failed = False
|
| 201 |
+
if 'current_img' in proj_dict:
|
| 202 |
+
current_img = proj_dict['current_img']
|
| 203 |
+
try:
|
| 204 |
+
self.set_current_img(current_img)
|
| 205 |
+
except (ImgnameNotInProjectException, RuntimeError) as e:
|
| 206 |
+
LOGGER.error(f"Failed to set current image {current_img}: {e}")
|
| 207 |
+
set_img_failed = True
|
| 208 |
+
else:
|
| 209 |
+
set_img_failed = True
|
| 210 |
+
LOGGER.warning(f'{current_img} not found.')
|
| 211 |
+
if set_img_failed:
|
| 212 |
+
if len(self.pages) > 0:
|
| 213 |
+
try:
|
| 214 |
+
self.set_current_img_byidx(0)
|
| 215 |
+
except RuntimeError as e:
|
| 216 |
+
LOGGER.error(f"Failed to set current image by index 0: {e}")
|
| 217 |
+
|
| 218 |
+
def load_translation_from_txt(self, file_path: str):
|
| 219 |
+
page_list = parse_txt_translation(file_path)
|
| 220 |
+
missing_pages = []
|
| 221 |
+
unmatched_pages = []
|
| 222 |
+
unexpected_pages = []
|
| 223 |
+
matched_pages = []
|
| 224 |
+
for page_dict in page_list:
|
| 225 |
+
page_name = page_dict['page_name']
|
| 226 |
+
if page_name in self.pages:
|
| 227 |
+
matched_pages.append(page_name)
|
| 228 |
+
else:
|
| 229 |
+
unexpected_pages.append(page_name)
|
| 230 |
+
continue
|
| 231 |
+
blklist = self.pages[page_name]
|
| 232 |
+
n_blk = len(blklist)
|
| 233 |
+
src_blk_list = page_dict['blk_list']
|
| 234 |
+
n_src_blk = len(src_blk_list)
|
| 235 |
+
if n_src_blk != n_blk:
|
| 236 |
+
LOGGER.warning(f'Unmatched text blocks in {page_name}, number of text blocks in this page vs source file: {n_blk}-{n_src_blk}')
|
| 237 |
+
unmatched_pages.append(page_name)
|
| 238 |
+
for blkid in range(min(n_blk, n_src_blk)):
|
| 239 |
+
blk = blklist[blkid]
|
| 240 |
+
blk.rich_text = ''
|
| 241 |
+
blk.translation = src_blk_list[blkid]
|
| 242 |
+
|
| 243 |
+
matched_pages = set(matched_pages)
|
| 244 |
+
if len(matched_pages) != self.num_pages:
|
| 245 |
+
for page_name in self.pages:
|
| 246 |
+
if page_name not in matched_pages:
|
| 247 |
+
missing_pages.append(page_name)
|
| 248 |
+
|
| 249 |
+
all_matched = len(missing_pages) == 0 and len(unmatched_pages) == 0 and len(unexpected_pages) == 0
|
| 250 |
+
return all_matched, {'missing_pages': missing_pages, 'unmatched_pages': unmatched_pages, 'unexpected_pages': unexpected_pages, 'matched_pages': matched_pages}
|
| 251 |
+
|
| 252 |
+
def load_from_json(self, json_path: str):
|
| 253 |
+
old_dir = self.directory
|
| 254 |
+
directory = osp.dirname(json_path)
|
| 255 |
+
try:
|
| 256 |
+
self.load(directory, json_path=json_path)
|
| 257 |
+
except Exception as e:
|
| 258 |
+
self.load(old_dir)
|
| 259 |
+
raise ProjectLoadFailureException(e)
|
| 260 |
+
|
| 261 |
+
def set_current_img(self, imgname: str):
|
| 262 |
+
if imgname is not None:
|
| 263 |
+
if imgname not in self.pages:
|
| 264 |
+
raise ImgnameNotInProjectException
|
| 265 |
+
self.current_img = imgname
|
| 266 |
+
img_path = self.current_img_path()
|
| 267 |
+
mask_path = self.get_mask_path(get_last_modified=True)
|
| 268 |
+
self.img_array = imread(img_path)
|
| 269 |
+
if self.img_array is None:
|
| 270 |
+
raise RuntimeError(f"Failed to load image: {img_path}")
|
| 271 |
+
im_h, im_w = self.img_array.shape[:2]
|
| 272 |
+
if osp.exists(mask_path):
|
| 273 |
+
self.mask_array = imread(mask_path, cv2.IMREAD_GRAYSCALE)
|
| 274 |
+
else:
|
| 275 |
+
self.mask_array = np.zeros((im_h, im_w), dtype=np.uint8)
|
| 276 |
+
self.inpainted_array = self.load_inpainted_by_imgname(imgname)
|
| 277 |
+
if self.inpainted_array is None:
|
| 278 |
+
self.inpainted_array = np.copy(self.img_array)
|
| 279 |
+
else:
|
| 280 |
+
self.current_img = None
|
| 281 |
+
self.img_array = None
|
| 282 |
+
self.mask_array = None
|
| 283 |
+
self.inpainted_array = None
|
| 284 |
+
|
| 285 |
+
def set_current_img_byidx(self, idx: int):
|
| 286 |
+
num_pages = self.num_pages
|
| 287 |
+
if idx < 0:
|
| 288 |
+
idx = idx + self.num_pages
|
| 289 |
+
if idx < 0 or idx > num_pages - 1:
|
| 290 |
+
self.set_current_img(None)
|
| 291 |
+
else:
|
| 292 |
+
self.set_current_img(self.idx2pagename(idx))
|
| 293 |
+
|
| 294 |
+
def get_blklist_byidx(self, idx: int) -> List[TextBlock]:
|
| 295 |
+
return self.pages[self.idx2pagename(idx)]
|
| 296 |
+
|
| 297 |
+
@property
|
| 298 |
+
def num_pages(self) -> int:
|
| 299 |
+
return len(self.pages)
|
| 300 |
+
|
| 301 |
+
@property
|
| 302 |
+
def current_idx(self) -> int:
|
| 303 |
+
return self.pagename2idx(self.current_img)
|
| 304 |
+
|
| 305 |
+
def new_project(self):
|
| 306 |
+
if not osp.exists(self.directory):
|
| 307 |
+
raise ProjectDirNotExistException
|
| 308 |
+
self.set_current_img(None)
|
| 309 |
+
imglist = find_all_imgs(self.directory, abs_path=False, sort=True)
|
| 310 |
+
self.pages = {}
|
| 311 |
+
self._pagename2idx = {}
|
| 312 |
+
self._idx2pagename = {}
|
| 313 |
+
for ii, imgname in enumerate(imglist):
|
| 314 |
+
self.pages[imgname] = []
|
| 315 |
+
self._pagename2idx[imgname] = ii
|
| 316 |
+
self._idx2pagename[ii] = imgname
|
| 317 |
+
self.set_current_img_byidx(0)
|
| 318 |
+
self.save()
|
| 319 |
+
|
| 320 |
+
def save(self):
|
| 321 |
+
if not osp.exists(self.directory):
|
| 322 |
+
raise ProjectDirNotExistException
|
| 323 |
+
with open(self.proj_path, "w", encoding="utf-8") as f:
|
| 324 |
+
f.write(json.dumps(self.to_dict(), ensure_ascii=False, cls=TextBlkEncoder))
|
| 325 |
+
LOGGER.debug(f'project saved to {self.proj_path}')
|
| 326 |
+
|
| 327 |
+
def to_dict(self) -> Dict:
|
| 328 |
+
pages = self.pages.copy()
|
| 329 |
+
pages.update(self.not_found_pages)
|
| 330 |
+
return {
|
| 331 |
+
'directory': self.directory,
|
| 332 |
+
'pages': pages,
|
| 333 |
+
'current_img': self.current_img,
|
| 334 |
+
'enable_watermark': self.enable_watermark,
|
| 335 |
+
'watermark_path': self.watermark_path,
|
| 336 |
+
'watermark_opacity': self.watermark_opacity
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
def read_img(self, imgname: str) -> np.ndarray:
|
| 340 |
+
if imgname not in self.pages:
|
| 341 |
+
raise ImgnameNotInProjectException
|
| 342 |
+
return imread(osp.join(self.directory, imgname))
|
| 343 |
+
|
| 344 |
+
def save_mask(self, img_name, mask: np.ndarray):
|
| 345 |
+
imwrite(self.get_mask_path(img_name), mask, ext=pcfg.intermediate_imgsave_ext)
|
| 346 |
+
|
| 347 |
+
def save_inpainted(self, img_name, inpainted: np.ndarray):
|
| 348 |
+
imwrite(self.get_inpainted_path(img_name), inpainted, ext=pcfg.intermediate_imgsave_ext)
|
| 349 |
+
|
| 350 |
+
def current_img_path(self) -> str:
|
| 351 |
+
if self.current_img is None:
|
| 352 |
+
return None
|
| 353 |
+
return osp.join(self.directory, self.current_img)
|
| 354 |
+
|
| 355 |
+
def get_mask_path(self, imgname: str = None, get_last_modified=False) -> str:
|
| 356 |
+
if imgname is None:
|
| 357 |
+
imgname = self.current_img
|
| 358 |
+
|
| 359 |
+
fileprefix = osp.join(self.mask_dir(), osp.splitext(imgname)[0])
|
| 360 |
+
if get_last_modified:
|
| 361 |
+
p = get_last_modified_file(fileprefix, ['.jxl', '.png'], ext_fallback=pcfg.intermediate_imgsave_ext)
|
| 362 |
+
else:
|
| 363 |
+
p = fileprefix+pcfg.intermediate_imgsave_ext
|
| 364 |
+
|
| 365 |
+
return p
|
| 366 |
+
|
| 367 |
+
def load_mask_by_imgname(self, imgname: str) -> np.ndarray:
|
| 368 |
+
mask = None
|
| 369 |
+
mp = self.get_mask_path(imgname, get_last_modified=True)
|
| 370 |
+
if osp.exists(mp):
|
| 371 |
+
mask = imread(mp, cv2.IMREAD_GRAYSCALE)
|
| 372 |
+
return mask
|
| 373 |
+
|
| 374 |
+
def get_inpainted_path(self, imgname: str = None, get_last_modified=False) -> str:
|
| 375 |
+
if imgname is None:
|
| 376 |
+
imgname = self.current_img
|
| 377 |
+
|
| 378 |
+
fileprefix = osp.join(self.inpainted_dir(), osp.splitext(imgname)[0])
|
| 379 |
+
if get_last_modified:
|
| 380 |
+
p = get_last_modified_file(fileprefix, ['.jxl', '.png'], ext_fallback=pcfg.intermediate_imgsave_ext)
|
| 381 |
+
else:
|
| 382 |
+
p = fileprefix+pcfg.intermediate_imgsave_ext
|
| 383 |
+
|
| 384 |
+
if not osp.exists(p) and shared.FUZZY_MATCH_IMAGE_NAME:
|
| 385 |
+
if self._fuzzy_inpainted_list is None:
|
| 386 |
+
if osp.exists(self.inpainted_dir()):
|
| 387 |
+
self._fuzzy_inpainted_list = find_all_imgs(self.inpainted_dir(), sort=True)
|
| 388 |
+
else:
|
| 389 |
+
self._fuzzy_inpainted_list = []
|
| 390 |
+
pidx = self.pagename2idx(imgname)
|
| 391 |
+
if pidx < len(self._fuzzy_inpainted_list):
|
| 392 |
+
return osp.join(self.inpainted_dir(), self._fuzzy_inpainted_list[pidx])
|
| 393 |
+
return p
|
| 394 |
+
|
| 395 |
+
def load_inpainted_by_imgname(self, imgname: str, scale_to_src: bool = True) -> np.ndarray:
|
| 396 |
+
inpainted = None
|
| 397 |
+
mp = self.get_inpainted_path(imgname, get_last_modified=True)
|
| 398 |
+
if mp is not None and osp.exists(mp):
|
| 399 |
+
inpainted = imread(mp)
|
| 400 |
+
if imgname == self.current_img and self.img_array is not None:
|
| 401 |
+
h, w = self.img_array.shape[:2]
|
| 402 |
+
else:
|
| 403 |
+
i = Image.open(osp.join(self.directory, imgname))
|
| 404 |
+
h, w = i.height, i.width
|
| 405 |
+
ih, iw = inpainted.shape[:2]
|
| 406 |
+
if ih != h or iw != w:
|
| 407 |
+
inpainted = Image.fromarray(inpainted).resize((w, h), resample=Image.Resampling.LANCZOS)
|
| 408 |
+
inpainted = np.array(inpainted)
|
| 409 |
+
return inpainted
|
| 410 |
+
|
| 411 |
+
def get_result_path(self, imgname: str) -> str:
|
| 412 |
+
ext = '.png'
|
| 413 |
+
if pcfg is not None:
|
| 414 |
+
if pcfg.imgsave_ext not in {'.jpg', '.png', '.webp', '.jxl'}:
|
| 415 |
+
LOGGER.warning('invalid image saving ext in config.json')
|
| 416 |
+
else:
|
| 417 |
+
ext = pcfg.imgsave_ext
|
| 418 |
+
return osp.join(self.result_dir(), osp.splitext(imgname)[0]+ext)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
def backup(self):
|
| 422 |
+
raise NotImplementedError
|
| 423 |
+
|
| 424 |
+
@property
|
| 425 |
+
def is_empty(self):
|
| 426 |
+
return len(self.pages) == 0
|
| 427 |
+
|
| 428 |
+
@property
|
| 429 |
+
def is_all_pages_no_text(self):
|
| 430 |
+
return all([len(blklist) == 0 for blklist in self.pages.values()])
|
| 431 |
+
|
| 432 |
+
@property
|
| 433 |
+
def img_valid(self):
|
| 434 |
+
return self.img_array is not None
|
| 435 |
+
|
| 436 |
+
@property
|
| 437 |
+
def mask_valid(self):
|
| 438 |
+
return self.mask_array is not None
|
| 439 |
+
|
| 440 |
+
@property
|
| 441 |
+
def inpainted_valid(self):
|
| 442 |
+
return self.inpainted_array is not None
|
| 443 |
+
|
| 444 |
+
def set_next_img(self):
|
| 445 |
+
if self.current_img is not None:
|
| 446 |
+
next_idx = (self.current_idx + 1) % self.num_pages
|
| 447 |
+
self.set_current_img(self.idx2pagename(next_idx))
|
| 448 |
+
|
| 449 |
+
def set_prev_img(self):
|
| 450 |
+
if self.current_img is not None:
|
| 451 |
+
next_idx = (self.current_idx - 1 + self.num_pages) % self.num_pages
|
| 452 |
+
self.set_current_img(self.idx2pagename(next_idx))
|
| 453 |
+
|
| 454 |
+
def current_block_list(self) -> List[TextBlock]:
|
| 455 |
+
if self.current_img is not None:
|
| 456 |
+
assert self.current_img in self.pages
|
| 457 |
+
return self.pages[self.current_img]
|
| 458 |
+
else:
|
| 459 |
+
return None
|
| 460 |
+
|
| 461 |
+
def doc_path(self) -> str:
|
| 462 |
+
return os.path.join(self.directory, self.proj_name() + ".docx")
|
| 463 |
+
|
| 464 |
+
def doc_exist(self) -> bool:
|
| 465 |
+
return osp.exists(self.doc_path())
|
| 466 |
+
|
| 467 |
+
def dump_doc(self, delete_tmp_folder=True, fin_page_signal=None):
|
| 468 |
+
|
| 469 |
+
cuts_dir = os.path.join(self.directory, "bubcuts")
|
| 470 |
+
if os.path.exists(cuts_dir):
|
| 471 |
+
shutil.rmtree(cuts_dir)
|
| 472 |
+
os.mkdir(cuts_dir)
|
| 473 |
+
|
| 474 |
+
document = Document()
|
| 475 |
+
style = document.styles['Normal']
|
| 476 |
+
font = style.font
|
| 477 |
+
target_font = 'Arial'
|
| 478 |
+
font.name = target_font
|
| 479 |
+
for pagename, blklist in self.pages.items():
|
| 480 |
+
imgpath = os.path.join(self.directory, pagename)
|
| 481 |
+
|
| 482 |
+
cuts_path_list, cut_width_list = gen_ballon_cuts(cuts_dir, imgpath, blklist)
|
| 483 |
+
paragraph = document.add_paragraph(pagename)
|
| 484 |
+
paragraph.style = document.styles['Normal']
|
| 485 |
+
table = document.add_table(rows=len(cuts_path_list), cols=2, style='Table Grid')
|
| 486 |
+
|
| 487 |
+
for index, (cut_path, width) in enumerate(zip(cuts_path_list, cut_width_list)):
|
| 488 |
+
run = table.cell(index, 0).paragraphs[0].add_run()
|
| 489 |
+
run.style.font.name = target_font
|
| 490 |
+
blk: TextBlock = blklist[index]
|
| 491 |
+
bubdict = vars(blk).copy()
|
| 492 |
+
bubdict["imgkey"] = pagename
|
| 493 |
+
bubdict["rich_text"] = ''
|
| 494 |
+
bubdict["text"] = blk.get_text()
|
| 495 |
+
write_jpg_metadata(cut_path, metadata=json.dumps(bubdict, ensure_ascii=False, cls=TextBlkEncoder))
|
| 496 |
+
run.add_picture(cut_path, width=Inches(width/96 * 0.85))
|
| 497 |
+
table.cell(index, 1).text = bubdict["translation"]
|
| 498 |
+
|
| 499 |
+
document.add_page_break()
|
| 500 |
+
|
| 501 |
+
if fin_page_signal is not None:
|
| 502 |
+
fin_page_signal.emit()
|
| 503 |
+
# time.sleep(1)
|
| 504 |
+
|
| 505 |
+
doc_path = self.doc_path()
|
| 506 |
+
document.save(doc_path)
|
| 507 |
+
if delete_tmp_folder:
|
| 508 |
+
shutil.rmtree(cuts_dir)
|
| 509 |
+
|
| 510 |
+
def dump_txt_path(self, dump_target, suffix):
|
| 511 |
+
save_path = osp.join(self.directory, self.proj_name() + f'_{dump_target}{suffix}')
|
| 512 |
+
return save_path
|
| 513 |
+
|
| 514 |
+
def dump_txt(self, dump_target: str, suffix='.txt'):
|
| 515 |
+
save_path = self.dump_txt_path(dump_target, suffix=suffix)
|
| 516 |
+
text_all = []
|
| 517 |
+
assert dump_target in {'source', 'translation'}
|
| 518 |
+
assert suffix in {'.txt', '.md'}
|
| 519 |
+
for page_name, blk_list in self.pages.items():
|
| 520 |
+
text_in_page = ['### ' + page_name]
|
| 521 |
+
for ii, blk in enumerate(blk_list):
|
| 522 |
+
if dump_target == 'translation':
|
| 523 |
+
text = blk.translation.strip()
|
| 524 |
+
elif dump_target == 'source':
|
| 525 |
+
text = blk.get_text().strip()
|
| 526 |
+
text_in_page.append(f'{ii + 1}. {text}')
|
| 527 |
+
text_all.append('\n\n'.join(text_in_page))
|
| 528 |
+
with open(save_path, 'w', encoding='utf8') as f:
|
| 529 |
+
f.write('\n\n\n'.join(text_all))
|
| 530 |
+
|
| 531 |
+
def load_doc(self, doc_path, delete_tmp_folder=True, fin_page_signal=None):
|
| 532 |
+
tmp_bubble_folder = osp.join(self.directory, 'img_folder')
|
| 533 |
+
os.makedirs(tmp_bubble_folder, exist_ok=True)
|
| 534 |
+
docx2txt.process(doc_path, tmp_bubble_folder)
|
| 535 |
+
|
| 536 |
+
doc = docx.Document(doc_path)
|
| 537 |
+
body_xml_str = doc._body._element.xml
|
| 538 |
+
|
| 539 |
+
pages = {}
|
| 540 |
+
bub_index = 0
|
| 541 |
+
for tbl in re.findall(r'<w:tbl>(.*?)</w:tbl>', body_xml_str, re.DOTALL):
|
| 542 |
+
for tr in re.findall(r'<w:tr(.*?)>(.*?)</w:tr>', tbl, re.DOTALL):
|
| 543 |
+
if re.findall(r'<pic:cNvPr id=\"(.*?)\" name=\"(.*?)\"(.*?)>', tr[1]):
|
| 544 |
+
bub_index += 1
|
| 545 |
+
translation = ""
|
| 546 |
+
for paragraph in re.findall(r'<w:p(.*?)>(.*?)</w:p>', tr[1], re.DOTALL):
|
| 547 |
+
for wt in re.findall(r'<w:t>(.*?)</w:t>', paragraph[1], re.DOTALL):
|
| 548 |
+
translation += wt
|
| 549 |
+
translation += "\n"
|
| 550 |
+
translation = translation[:-1]
|
| 551 |
+
if len(translation) != 0 and translation[0] == "\n":
|
| 552 |
+
translation = translation[1:]
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
bubpath = os.path.join(tmp_bubble_folder, "image"+str(bub_index))
|
| 556 |
+
if osp.exists(bubpath+'.jpg'):
|
| 557 |
+
bubpath = bubpath + '.jpg'
|
| 558 |
+
else:
|
| 559 |
+
bubpath = bubpath + '.jpeg'
|
| 560 |
+
|
| 561 |
+
meta_dict = read_jpg_metadata(bubpath)
|
| 562 |
+
meta_dict["translation"] = translation
|
| 563 |
+
imgkey = meta_dict.pop("imgkey")
|
| 564 |
+
if not imgkey in pages:
|
| 565 |
+
pages[imgkey] = []
|
| 566 |
+
pages[imgkey].append(TextBlock(**meta_dict))
|
| 567 |
+
|
| 568 |
+
if fin_page_signal is not None:
|
| 569 |
+
fin_page_signal.emit()
|
| 570 |
+
|
| 571 |
+
self.merge_from_proj_dict(pages)
|
| 572 |
+
if delete_tmp_folder:
|
| 573 |
+
shutil.rmtree(tmp_bubble_folder)
|
| 574 |
+
|
| 575 |
+
def merge_from_proj_dict(self, tgt_dict: Dict) -> Dict:
|
| 576 |
+
if self.pages is None:
|
| 577 |
+
self.pages = {}
|
| 578 |
+
src_dict = self.pages if self.pages is not None else {}
|
| 579 |
+
key_lst = list(dict.fromkeys(list(src_dict.keys()) + list(tgt_dict.keys())))
|
| 580 |
+
key_lst.sort()
|
| 581 |
+
rst_dict = {}
|
| 582 |
+
pagename2idx = {}
|
| 583 |
+
idx2pagename = {}
|
| 584 |
+
page_counter = 0
|
| 585 |
+
for key in key_lst:
|
| 586 |
+
if key in src_dict and not key in tgt_dict:
|
| 587 |
+
rst_dict[key] = src_dict[key]
|
| 588 |
+
else:
|
| 589 |
+
rst_dict[key] = tgt_dict[key]
|
| 590 |
+
pagename2idx[key] = page_counter
|
| 591 |
+
idx2pagename[page_counter] = key
|
| 592 |
+
page_counter += 1
|
| 593 |
+
self.pages.clear()
|
| 594 |
+
self.pages.update(rst_dict)
|
| 595 |
+
self._pagename2idx = pagename2idx
|
| 596 |
+
self._idx2pagename = idx2pagename
|
| 597 |
+
|
| 598 |
+
|
| 599 |
+
def gen_ballon_cuts(cuts_dir: str, imgpath: str, blk_list: List[TextBlock], resize=True) -> Tuple[List[str], List[int]]:
|
| 600 |
+
img = imread(imgpath)
|
| 601 |
+
imgname = os.path.basename(imgpath)
|
| 602 |
+
cuts_path_list = []
|
| 603 |
+
cut_width_list = []
|
| 604 |
+
for ii, blk in enumerate(blk_list):
|
| 605 |
+
|
| 606 |
+
x, y, w, h = blk.bounding_rect()
|
| 607 |
+
x, y = max(x, 0), max(y, 0)
|
| 608 |
+
w = max(w, 1)
|
| 609 |
+
h = max(h, 1)
|
| 610 |
+
x1, y1, x2, y2 = int(x), int(y), int(x+w), int(y+h)
|
| 611 |
+
|
| 612 |
+
cut_path = os.path.join(cuts_dir, f'{imgname}-{ii}.jpg')
|
| 613 |
+
bub = img[y1:y2, x1:x2]
|
| 614 |
+
max_width = 448
|
| 615 |
+
|
| 616 |
+
if bub.shape[0] < 1 or bub.shape[1] < 1:
|
| 617 |
+
emptyw = 60
|
| 618 |
+
resized = np.full((emptyw, emptyw, 3), fill_value=0, dtype=np.uint8)
|
| 619 |
+
width = emptyw
|
| 620 |
+
else:
|
| 621 |
+
# scale_percent = 60 # percent of original size
|
| 622 |
+
scale_percent = min(1920 / img.shape[0], max_width / w)
|
| 623 |
+
|
| 624 |
+
if scale_percent < 1:
|
| 625 |
+
width = max(1, int(bub.shape[1] * scale_percent))
|
| 626 |
+
height = max(1, int(bub.shape[0] * scale_percent))
|
| 627 |
+
dim = (width, height)
|
| 628 |
+
resized = cv2.resize(bub, dim, interpolation = cv2.INTER_AREA) if resize else bub
|
| 629 |
+
else:
|
| 630 |
+
width = w
|
| 631 |
+
resized = bub
|
| 632 |
+
|
| 633 |
+
imwrite(cut_path, resized, '.jpg')
|
| 634 |
+
cuts_path_list.append(cut_path)
|
| 635 |
+
cut_width_list.append(width)
|
| 636 |
+
|
| 637 |
+
return cuts_path_list, cut_width_list
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
def save_image_with_watermark(
|
| 641 |
+
img: 'Union[np.ndarray, Image.Image]',
|
| 642 |
+
output_path: str,
|
| 643 |
+
watermark_path: str = None,
|
| 644 |
+
watermark_opacity: float = 0.7,
|
| 645 |
+
quality: int = 95
|
| 646 |
+
) -> bool:
|
| 647 |
+
"""Save image with optional watermark applied."""
|
| 648 |
+
try:
|
| 649 |
+
import cv2
|
| 650 |
+
from PIL import Image
|
| 651 |
+
if isinstance(img, np.ndarray):
|
| 652 |
+
if img.ndim == 2:
|
| 653 |
+
img_pil = Image.fromarray(img)
|
| 654 |
+
elif img.shape[2] == 4:
|
| 655 |
+
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
|
| 656 |
+
else:
|
| 657 |
+
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
| 658 |
+
else:
|
| 659 |
+
img_pil = img
|
| 660 |
+
|
| 661 |
+
import os.path as osp
|
| 662 |
+
from utils.watermark_utils import apply_watermark_to_pil_image
|
| 663 |
+
|
| 664 |
+
if watermark_path and osp.exists(watermark_path):
|
| 665 |
+
img_pil = apply_watermark_to_pil_image(img_pil, watermark_path, watermark_opacity)
|
| 666 |
+
|
| 667 |
+
ext = osp.splitext(output_path)[1].lower()
|
| 668 |
+
img_format = "PNG"
|
| 669 |
+
if ext in ['.jpg', '.jpeg']:
|
| 670 |
+
img_format = "JPEG"
|
| 671 |
+
if img_pil.mode == 'RGBA':
|
| 672 |
+
img_pil = img_pil.convert('RGB')
|
| 673 |
+
elif ext == '.webp':
|
| 674 |
+
img_format = "WEBP"
|
| 675 |
+
elif ext == '.jxl':
|
| 676 |
+
img_format = "JPEG2000"
|
| 677 |
+
|
| 678 |
+
save_kwargs = {'format': img_format}
|
| 679 |
+
if img_format in ['JPEG', 'WEBP', 'JPEG2000']:
|
| 680 |
+
save_kwargs['quality'] = quality
|
| 681 |
+
elif img_format == 'PNG':
|
| 682 |
+
save_kwargs['compress_level'] = 3
|
| 683 |
+
|
| 684 |
+
img_pil.save(output_path, **save_kwargs)
|
| 685 |
+
return True
|
| 686 |
+
except Exception as e:
|
| 687 |
+
LOGGER.error(f"Error saving image with watermark: {str(e)}")
|
| 688 |
+
return False
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
def save_result(self, imgname: str, img: np.ndarray) -> bool:
|
| 692 |
+
output_path = self.get_result_path(imgname)
|
| 693 |
+
|
| 694 |
+
if self.watermark_enabled and self.watermark_path:
|
| 695 |
+
# تحويل إلى صورة PIL
|
| 696 |
+
if img.ndim == 2:
|
| 697 |
+
img_pil = Image.fromarray(img)
|
| 698 |
+
elif img.shape[2] == 3:
|
| 699 |
+
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
|
| 700 |
+
elif img.shape[2] == 4:
|
| 701 |
+
img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA))
|
| 702 |
+
|
| 703 |
+
# تطبيق العلامة المائية
|
| 704 |
+
img_pil = apply_watermark_to_pil_image(
|
| 705 |
+
img_pil,
|
| 706 |
+
self.watermark_path,
|
| 707 |
+
self.watermark_opacity
|
| 708 |
+
)
|
| 709 |
+
|
| 710 |
+
# التحويل مرة أخرى إلى numpy array
|
| 711 |
+
if img_pil.mode == 'RGB':
|
| 712 |
+
img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
|
| 713 |
+
elif img_pil.mode == 'RGBA':
|
| 714 |
+
img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGBA2BGRA)
|
| 715 |
+
else:
|
| 716 |
+
img = np.array(img_pil)
|
| 717 |
+
|
| 718 |
+
return imwrite(output_path, img)
|
| 719 |
+
|
| 720 |
+
|
utils/registry.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/registry.py
|
| 2 |
+
|
| 3 |
+
import inspect
|
| 4 |
+
import warnings
|
| 5 |
+
from functools import partial
|
| 6 |
+
|
| 7 |
+
class Registry:
|
| 8 |
+
"""A registry to map strings to classes.
|
| 9 |
+
|
| 10 |
+
Registered object could be built from registry.
|
| 11 |
+
|
| 12 |
+
Example:
|
| 13 |
+
>>> MODELS = Registry('models')
|
| 14 |
+
>>> @MODELS.register_module()
|
| 15 |
+
>>> class ResNet:
|
| 16 |
+
>>> pass
|
| 17 |
+
>>> resnet = MODELS.build(dict(type='ResNet'))
|
| 18 |
+
|
| 19 |
+
Please refer to
|
| 20 |
+
https://mmcv.readthedocs.io/en/latest/understand_mmcv/registry.html for
|
| 21 |
+
advanced usage.
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
name (str): Registry name.
|
| 25 |
+
build_func(func, optional): Build function to construct instance from
|
| 26 |
+
Registry, func:`build_from_cfg` is used if neither ``parent`` or
|
| 27 |
+
``build_func`` is specified. If ``parent`` is specified and
|
| 28 |
+
``build_func`` is not given, ``build_func`` will be inherited
|
| 29 |
+
from ``parent``. Default: None.
|
| 30 |
+
parent (Registry, optional): Parent registry. The class registered in
|
| 31 |
+
children registry could be built from parent. Default: None.
|
| 32 |
+
scope (str, optional): The scope of registry. It is the key to search
|
| 33 |
+
for children registry. If not specified, scope will be the name of
|
| 34 |
+
the package where class is defined, e.g. mmdet, mmcls, mmseg.
|
| 35 |
+
Default: None.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, name, build_func=None, parent=None, scope=None):
|
| 39 |
+
self._name = name
|
| 40 |
+
self._module_dict = dict()
|
| 41 |
+
self._children = dict()
|
| 42 |
+
|
| 43 |
+
# self._scope = self.infer_scope() if scope is None else scope
|
| 44 |
+
|
| 45 |
+
# self.build_func will be set with the following priority:
|
| 46 |
+
# 1. build_func
|
| 47 |
+
# 2. parent.build_func
|
| 48 |
+
# 3. build_from_cfg
|
| 49 |
+
# if build_func is None:
|
| 50 |
+
# if parent is not None:
|
| 51 |
+
# self.build_func = parent.build_func
|
| 52 |
+
# else:
|
| 53 |
+
# self.build_func = build_from_cfg
|
| 54 |
+
# else:
|
| 55 |
+
# self.build_func = build_func
|
| 56 |
+
if parent is not None:
|
| 57 |
+
assert isinstance(parent, Registry)
|
| 58 |
+
parent._add_children(self)
|
| 59 |
+
self.parent = parent
|
| 60 |
+
else:
|
| 61 |
+
self.parent = None
|
| 62 |
+
|
| 63 |
+
def __len__(self):
|
| 64 |
+
return len(self._module_dict)
|
| 65 |
+
|
| 66 |
+
def __contains__(self, key):
|
| 67 |
+
return self.get(key) is not None
|
| 68 |
+
|
| 69 |
+
def __repr__(self):
|
| 70 |
+
format_str = self.__class__.__name__ + \
|
| 71 |
+
f'(name={self._name}, ' \
|
| 72 |
+
f'items={self._module_dict})'
|
| 73 |
+
return format_str
|
| 74 |
+
|
| 75 |
+
@staticmethod
|
| 76 |
+
def infer_scope():
|
| 77 |
+
"""Infer the scope of registry.
|
| 78 |
+
|
| 79 |
+
The name of the package where registry is defined will be returned.
|
| 80 |
+
|
| 81 |
+
Example:
|
| 82 |
+
>>> # in mmdet/models/backbone/resnet.py
|
| 83 |
+
>>> MODELS = Registry('models')
|
| 84 |
+
>>> @MODELS.register_module()
|
| 85 |
+
>>> class ResNet:
|
| 86 |
+
>>> pass
|
| 87 |
+
The scope of ``ResNet`` will be ``mmdet``.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
str: The inferred scope name.
|
| 91 |
+
"""
|
| 92 |
+
# inspect.stack() trace where this function is called, the index-2
|
| 93 |
+
# indicates the frame where `infer_scope()` is called
|
| 94 |
+
filename = inspect.getmodule(inspect.stack()[2][0]).__name__
|
| 95 |
+
split_filename = filename.split('.')
|
| 96 |
+
return split_filename[0]
|
| 97 |
+
|
| 98 |
+
@staticmethod
|
| 99 |
+
def split_scope_key(key):
|
| 100 |
+
"""Split scope and key.
|
| 101 |
+
|
| 102 |
+
The first scope will be split from key.
|
| 103 |
+
|
| 104 |
+
Examples:
|
| 105 |
+
>>> Registry.split_scope_key('mmdet.ResNet')
|
| 106 |
+
'mmdet', 'ResNet'
|
| 107 |
+
>>> Registry.split_scope_key('ResNet')
|
| 108 |
+
None, 'ResNet'
|
| 109 |
+
|
| 110 |
+
Return:
|
| 111 |
+
tuple[str | None, str]: The former element is the first scope of
|
| 112 |
+
the key, which can be ``None``. The latter is the remaining key.
|
| 113 |
+
"""
|
| 114 |
+
split_index = key.find('.')
|
| 115 |
+
if split_index != -1:
|
| 116 |
+
return key[:split_index], key[split_index + 1:]
|
| 117 |
+
else:
|
| 118 |
+
return None, key
|
| 119 |
+
|
| 120 |
+
@property
|
| 121 |
+
def name(self):
|
| 122 |
+
return self._name
|
| 123 |
+
|
| 124 |
+
# @property
|
| 125 |
+
# def scope(self):
|
| 126 |
+
# return self._scope
|
| 127 |
+
|
| 128 |
+
@property
|
| 129 |
+
def module_dict(self):
|
| 130 |
+
return self._module_dict
|
| 131 |
+
|
| 132 |
+
@property
|
| 133 |
+
def children(self):
|
| 134 |
+
return self._children
|
| 135 |
+
|
| 136 |
+
def get(self, key):
|
| 137 |
+
"""Get the registry record.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
key (str): The class name in string format.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
class: The corresponding class.
|
| 144 |
+
"""
|
| 145 |
+
scope, real_key = self.split_scope_key(key)
|
| 146 |
+
if scope is None or scope == self._scope:
|
| 147 |
+
# get from self
|
| 148 |
+
if real_key in self._module_dict:
|
| 149 |
+
return self._module_dict[real_key]
|
| 150 |
+
else:
|
| 151 |
+
# get from self._children
|
| 152 |
+
if scope in self._children:
|
| 153 |
+
return self._children[scope].get(real_key)
|
| 154 |
+
else:
|
| 155 |
+
# goto root
|
| 156 |
+
parent = self.parent
|
| 157 |
+
while parent.parent is not None:
|
| 158 |
+
parent = parent.parent
|
| 159 |
+
return parent.get(key)
|
| 160 |
+
|
| 161 |
+
# def build(self, *args, **kwargs):
|
| 162 |
+
# return self.build_func(*args, **kwargs, registry=self)
|
| 163 |
+
|
| 164 |
+
def _add_children(self, registry):
|
| 165 |
+
"""Add children for a registry.
|
| 166 |
+
|
| 167 |
+
The ``registry`` will be added as children based on its scope.
|
| 168 |
+
The parent registry could build objects from children registry.
|
| 169 |
+
|
| 170 |
+
Example:
|
| 171 |
+
>>> models = Registry('models')
|
| 172 |
+
>>> mmdet_models = Registry('models', parent=models)
|
| 173 |
+
>>> @mmdet_models.register_module()
|
| 174 |
+
>>> class ResNet:
|
| 175 |
+
>>> pass
|
| 176 |
+
>>> resnet = models.build(dict(type='mmdet.ResNet'))
|
| 177 |
+
"""
|
| 178 |
+
|
| 179 |
+
assert isinstance(registry, Registry)
|
| 180 |
+
assert registry.scope is not None
|
| 181 |
+
assert registry.scope not in self.children, \
|
| 182 |
+
f'scope {registry.scope} exists in {self.name} registry'
|
| 183 |
+
self.children[registry.scope] = registry
|
| 184 |
+
|
| 185 |
+
def _register_module(self, module_class, module_name=None, force=False):
|
| 186 |
+
if not inspect.isclass(module_class):
|
| 187 |
+
raise TypeError('module must be a class, '
|
| 188 |
+
f'but got {type(module_class)}')
|
| 189 |
+
|
| 190 |
+
if module_name is None:
|
| 191 |
+
module_name = module_class.__name__
|
| 192 |
+
if isinstance(module_name, str):
|
| 193 |
+
module_name = [module_name]
|
| 194 |
+
|
| 195 |
+
for name in module_name:
|
| 196 |
+
if not force and name in self._module_dict:
|
| 197 |
+
raise KeyError(f'{name} is already registered '
|
| 198 |
+
f'in {self.name}')
|
| 199 |
+
self._module_dict[name] = module_class
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def deprecated_register_module(self, cls=None, force=False):
|
| 203 |
+
warnings.warn(
|
| 204 |
+
'The old API of register_module(module, force=False) '
|
| 205 |
+
'is deprecated and will be removed, please use the new API '
|
| 206 |
+
'register_module(name=None, force=False, module=None) instead.',
|
| 207 |
+
DeprecationWarning)
|
| 208 |
+
if cls is None:
|
| 209 |
+
return partial(self.deprecated_register_module, force=force)
|
| 210 |
+
self._register_module(cls, force=force)
|
| 211 |
+
return cls
|
| 212 |
+
|
| 213 |
+
def register_module(self, name=None, force=False, module=None):
|
| 214 |
+
"""Register a module.
|
| 215 |
+
|
| 216 |
+
A record will be added to `self._module_dict`, whose key is the class
|
| 217 |
+
name or the specified name, and value is the class itself.
|
| 218 |
+
It can be used as a decorator or a normal function.
|
| 219 |
+
|
| 220 |
+
Example:
|
| 221 |
+
>>> backbones = Registry('backbone')
|
| 222 |
+
>>> @backbones.register_module()
|
| 223 |
+
>>> class ResNet:
|
| 224 |
+
>>> pass
|
| 225 |
+
|
| 226 |
+
>>> backbones = Registry('backbone')
|
| 227 |
+
>>> @backbones.register_module(name='mnet')
|
| 228 |
+
>>> class MobileNet:
|
| 229 |
+
>>> pass
|
| 230 |
+
|
| 231 |
+
>>> backbones = Registry('backbone')
|
| 232 |
+
>>> class ResNet:
|
| 233 |
+
>>> pass
|
| 234 |
+
>>> backbones.register_module(ResNet)
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
name (str | None): The module name to be registered. If not
|
| 238 |
+
specified, the class name will be used.
|
| 239 |
+
force (bool, optional): Whether to override an existing class with
|
| 240 |
+
the same name. Default: False.
|
| 241 |
+
module (type): Module class to be registered.
|
| 242 |
+
"""
|
| 243 |
+
if not isinstance(force, bool):
|
| 244 |
+
raise TypeError(f'force must be a boolean, but got {type(force)}')
|
| 245 |
+
# NOTE: This is a walkaround to be compatible with the old api,
|
| 246 |
+
# while it may introduce unexpected bugs.
|
| 247 |
+
if isinstance(name, type):
|
| 248 |
+
return self.deprecated_register_module(name, force=force)
|
| 249 |
+
|
| 250 |
+
# raise the error ahead of time
|
| 251 |
+
if not (name is None or isinstance(name, str)):
|
| 252 |
+
raise TypeError(
|
| 253 |
+
'name must be either of None, an instance of str or a sequence'
|
| 254 |
+
f' of str, but got {type(name)}')
|
| 255 |
+
|
| 256 |
+
# use it as a normal method: x.register_module(module=SomeClass)
|
| 257 |
+
if module is not None:
|
| 258 |
+
|
| 259 |
+
self._register_module(
|
| 260 |
+
module_class=module, module_name=name, force=force)
|
| 261 |
+
return module
|
| 262 |
+
|
| 263 |
+
# use it as a decorator: @x.register_module()
|
| 264 |
+
def _register(cls):
|
| 265 |
+
self._register_module(
|
| 266 |
+
module_class=cls, module_name=name, force=force)
|
| 267 |
+
return cls
|
| 268 |
+
|
| 269 |
+
return _register
|
| 270 |
+
|
| 271 |
+
def __getitem__(self, key: str):
|
| 272 |
+
return self.get(key)
|
utils/shared.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict
|
| 2 |
+
import os
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import json
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
ICON_PATH = 'icons/icon.icns'
|
| 8 |
+
|
| 9 |
+
PROGRAM_PATH = osp.abspath(osp.dirname(osp.dirname(__file__)))
|
| 10 |
+
LOGGING_PATH = osp.join(PROGRAM_PATH, 'logs')
|
| 11 |
+
|
| 12 |
+
LIBS_PATH = osp.join(PROGRAM_PATH, 'data/libs')
|
| 13 |
+
|
| 14 |
+
STYLESHEET_PATH = osp.join(PROGRAM_PATH, 'config/stylesheet.css')
|
| 15 |
+
THEME_PATH = osp.join(PROGRAM_PATH, 'config/themes.json')
|
| 16 |
+
CONFIG_PATH = osp.join(PROGRAM_PATH, 'config/config.json')
|
| 17 |
+
|
| 18 |
+
DEFAULT_TEXTSTYLE_DIR = osp.join(PROGRAM_PATH, 'config/textstyles')
|
| 19 |
+
if not osp.exists(DEFAULT_TEXTSTYLE_DIR):
|
| 20 |
+
os.makedirs(DEFAULT_TEXTSTYLE_DIR)
|
| 21 |
+
|
| 22 |
+
st_manager = None
|
| 23 |
+
CONFIG_FONTSIZE_HEADER = 18
|
| 24 |
+
CONFIG_FONTSIZE_TABLE = 16
|
| 25 |
+
CONFIG_FONTSIZE_CONTENT = 16
|
| 26 |
+
|
| 27 |
+
CONFIG_COMBOBOX_HEIGHT = 30
|
| 28 |
+
CONFIG_COMBOBOX_SHORT = 200
|
| 29 |
+
CONFIG_COMBOBOX_MIDEAN = 332
|
| 30 |
+
CONFIG_COMBOBOX_LONG = 468
|
| 31 |
+
|
| 32 |
+
_size2width = {
|
| 33 |
+
'short': CONFIG_COMBOBOX_SHORT,
|
| 34 |
+
'median': CONFIG_COMBOBOX_MIDEAN,
|
| 35 |
+
'long':CONFIG_COMBOBOX_LONG
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
def size2width(size: str):
|
| 39 |
+
global _size2width
|
| 40 |
+
return _size2width[size]
|
| 41 |
+
|
| 42 |
+
HORSLIDER_FIXHEIGHT = 36
|
| 43 |
+
|
| 44 |
+
WIDGET_SPACING_CLOSE = 8
|
| 45 |
+
TEXTEDIT_FIXWIDTH = 350
|
| 46 |
+
|
| 47 |
+
TEXTEFFECT_FIXWIDTH = 400
|
| 48 |
+
TEXTEFFECT_MAXHEIGHT = 500
|
| 49 |
+
|
| 50 |
+
LEFTBAR_WIDTH = 48
|
| 51 |
+
LEFTBTN_WIDTH = 28
|
| 52 |
+
|
| 53 |
+
LDPI = 96.
|
| 54 |
+
DPI = 188.75
|
| 55 |
+
|
| 56 |
+
SCREEN_H = 2160
|
| 57 |
+
SCREEN_W = 3840
|
| 58 |
+
|
| 59 |
+
DEFAULT_FONT_FAMILY = 'Microsoft YaHei UI'
|
| 60 |
+
APP_DEFAULT_FONT = 'Microsoft YaHei UI'
|
| 61 |
+
|
| 62 |
+
WINDOW_BORDER_WIDTH = 4
|
| 63 |
+
BOTTOMBAR_HEIGHT = 32
|
| 64 |
+
TITLEBAR_HEIGHT = 30
|
| 65 |
+
|
| 66 |
+
PAGELIST_THUMBNAIL_MAXNUM = 100
|
| 67 |
+
PAGELIST_THUMBNAIL_SIZE = 48
|
| 68 |
+
|
| 69 |
+
FLAG_QT6 = True
|
| 70 |
+
|
| 71 |
+
SLIDERHANDLE_COLOR = (85,85,96)
|
| 72 |
+
FOREGROUND_FONTCOLOR = (93,93,95)
|
| 73 |
+
|
| 74 |
+
MAX_NUM_LOG = 7
|
| 75 |
+
|
| 76 |
+
TRANSLATE_DIR = osp.join(PROGRAM_PATH, 'translate')
|
| 77 |
+
DISPLAY_LANGUAGE_MAP = {
|
| 78 |
+
"English": "English",
|
| 79 |
+
"简体中文": "zh_CN",
|
| 80 |
+
"Русский": "ru_RU",
|
| 81 |
+
"Português (Brasil)": "pt_BR",
|
| 82 |
+
"한국어": "ko_KR",
|
| 83 |
+
"Español": "es_MX",
|
| 84 |
+
"Hungarian": "hu_HU"
|
| 85 |
+
}
|
| 86 |
+
VALID_LANG_SET = set(list(DISPLAY_LANGUAGE_MAP.values()))
|
| 87 |
+
|
| 88 |
+
for p in os.listdir(TRANSLATE_DIR):
|
| 89 |
+
if p.endswith('.qm'):
|
| 90 |
+
lang = p.replace('.qm', '')
|
| 91 |
+
if lang not in VALID_LANG_SET:
|
| 92 |
+
DISPLAY_LANGUAGE_MAP[lang] = lang
|
| 93 |
+
|
| 94 |
+
DEFAULT_DISPLAY_LANG = 'English'
|
| 95 |
+
|
| 96 |
+
USE_PYSIDE6 = False
|
| 97 |
+
ON_MACOS = sys.platform == 'darwin'
|
| 98 |
+
ON_WINDOWS = sys.platform == 'win32'
|
| 99 |
+
HEADLESS = False
|
| 100 |
+
DEBUG = False
|
| 101 |
+
args = None
|
| 102 |
+
|
| 103 |
+
FUZZY_MATCH_IMAGE_NAME = False
|
| 104 |
+
|
| 105 |
+
cache_data: Dict = None
|
| 106 |
+
cache_dir: str = osp.join(PROGRAM_PATH, '.btrans_cache')
|
| 107 |
+
cache_path: str = osp.join(PROGRAM_PATH, '.btrans_cache/cache.json')
|
| 108 |
+
CACHE_UPDATED = False
|
| 109 |
+
check_local_file_hash = True
|
| 110 |
+
|
| 111 |
+
FONT_FAMILIES: set = None
|
| 112 |
+
CUSTOM_FONTS = []
|
| 113 |
+
pbar = {}
|
| 114 |
+
runtime_widget_set = set()
|
| 115 |
+
|
| 116 |
+
def add_to_runtime_widget_set(widget):
|
| 117 |
+
runtime_widget_set.add(widget)
|
| 118 |
+
|
| 119 |
+
def remove_from_runtime_widget_set(widget):
|
| 120 |
+
if widget in runtime_widget_set:
|
| 121 |
+
runtime_widget_set.remove(widget)
|
| 122 |
+
|
| 123 |
+
showed_exception = set()
|
| 124 |
+
|
| 125 |
+
# it will be set to ui.mainwindow.create_errdialog.emit after UI initialized
|
| 126 |
+
create_errdialog_in_mainthread = lambda *args, **kwargs: None
|
| 127 |
+
|
| 128 |
+
create_infodialog_in_mainthread = lambda *args, **kwargs: None
|
| 129 |
+
|
| 130 |
+
def load_cache():
|
| 131 |
+
global cache_data
|
| 132 |
+
if cache_data is None:
|
| 133 |
+
if osp.exists(cache_path):
|
| 134 |
+
try:
|
| 135 |
+
with open(cache_path, "r", encoding="utf8") as file:
|
| 136 |
+
cache_data = json.load(file)
|
| 137 |
+
except:
|
| 138 |
+
print(f'cached file {cache_path} is invalid')
|
| 139 |
+
cache_data = {}
|
| 140 |
+
else:
|
| 141 |
+
cache_data = {}
|
| 142 |
+
|
| 143 |
+
def dump_cache():
|
| 144 |
+
global cache_data
|
| 145 |
+
if cache_data is None:
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
cache_dir = osp.dirname(cache_path)
|
| 149 |
+
if not osp.exists(cache_dir):
|
| 150 |
+
os.makedirs(cache_dir)
|
| 151 |
+
|
| 152 |
+
with open(cache_path, "w", encoding="utf8") as file:
|
| 153 |
+
json.dump(cache_data, file, indent=4)
|
| 154 |
+
|
| 155 |
+
global CACHE_UPDATED
|
| 156 |
+
CACHE_UPDATED = False
|
| 157 |
+
|
| 158 |
+
config_name_to_view_widget = {}
|
| 159 |
+
action_to_view_config_name = {}
|
| 160 |
+
register_view_widget: lambda *args, **kwargs: None
|
utils/split_text_region.py
ADDED
|
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2, os, re, random
|
| 2 |
+
import numpy as np
|
| 3 |
+
# import tesserocr
|
| 4 |
+
# from tesserocr import PyTessBaseAPI, PSM, OEM
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class TextSpan(object):
|
| 9 |
+
def __init__(self, top_bnd=None, bottom_bnd=None, left_bnd=None, right_bnd=None):
|
| 10 |
+
self.top = top_bnd
|
| 11 |
+
self.bottom = bottom_bnd
|
| 12 |
+
self.height = self.bottom - self.top if bottom_bnd is not None else None
|
| 13 |
+
|
| 14 |
+
self.left = left_bnd
|
| 15 |
+
self.right = right_bnd
|
| 16 |
+
self.width = self.right - self.left if right_bnd is not None else None
|
| 17 |
+
|
| 18 |
+
def set_top(self, top_bnd):
|
| 19 |
+
self.top = top_bnd
|
| 20 |
+
return True
|
| 21 |
+
|
| 22 |
+
def set_bottom(self, bottom_bnd):
|
| 23 |
+
if self.top is None or bottom_bnd <= self.top:
|
| 24 |
+
return False
|
| 25 |
+
self.bottom = bottom_bnd
|
| 26 |
+
self.height = self.bottom - self.top
|
| 27 |
+
return True
|
| 28 |
+
|
| 29 |
+
def set_left(self, left_bnd):
|
| 30 |
+
self.left = left_bnd
|
| 31 |
+
return True
|
| 32 |
+
|
| 33 |
+
def set_right(self, right_bnd):
|
| 34 |
+
if self.left is None or right_bnd <= self.left:
|
| 35 |
+
return False
|
| 36 |
+
self.right = right_bnd
|
| 37 |
+
self.width = right_bnd - self.left
|
| 38 |
+
return True
|
| 39 |
+
|
| 40 |
+
def __getitem__(self, index):
|
| 41 |
+
if isinstance(index, int) and index >=0 and index < 4:
|
| 42 |
+
return [self.left, self.top, self.right, self.bottom][index]
|
| 43 |
+
else:
|
| 44 |
+
raise AttributeError(f'Invalid key: {index}')
|
| 45 |
+
|
| 46 |
+
def split_step0(span, thresh, sumby_yaxis, thresh2=None) -> list[TextSpan]:
|
| 47 |
+
candidate_pnts = (np.where(sumby_yaxis[span.top: span.bottom] > thresh)[0] + span.top).tolist()
|
| 48 |
+
span_list = []
|
| 49 |
+
if len(candidate_pnts) == 0:
|
| 50 |
+
return None
|
| 51 |
+
stride_tol = 1
|
| 52 |
+
span0, span1 = TextSpan(candidate_pnts[0]), TextSpan()
|
| 53 |
+
for pnt_ind in range(len(candidate_pnts)-1):
|
| 54 |
+
if candidate_pnts[pnt_ind+1] - candidate_pnts[pnt_ind] > stride_tol:
|
| 55 |
+
if not span0.set_bottom(candidate_pnts[pnt_ind]):
|
| 56 |
+
continue
|
| 57 |
+
span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)
|
| 58 |
+
span1.set_top(candidate_pnts[pnt_ind+1])
|
| 59 |
+
span0 = span1
|
| 60 |
+
span1 = TextSpan()
|
| 61 |
+
|
| 62 |
+
if len(candidate_pnts)-1 == 0:
|
| 63 |
+
if candidate_pnts[0] == candidate_pnts[-1]:
|
| 64 |
+
span_list = None
|
| 65 |
+
else:
|
| 66 |
+
span0 = TextSpan(candidate_pnts[0], candidate_pnts[-1])
|
| 67 |
+
span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)
|
| 68 |
+
elif span0.top != candidate_pnts[-1]:
|
| 69 |
+
span0.set_bottom(candidate_pnts[-1])
|
| 70 |
+
span_list = split_step1(span0, span_list, thresh=thresh2, sumby_yaxis=sumby_yaxis)
|
| 71 |
+
|
| 72 |
+
return span_list
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def split_step1(span, span_list, thresh=None, sumby_yaxis=None):
|
| 77 |
+
if thresh is None:
|
| 78 |
+
span_list.append(span)
|
| 79 |
+
return span_list
|
| 80 |
+
else:
|
| 81 |
+
subspan_list = split_step0(span, thresh, sumby_yaxis)
|
| 82 |
+
# print(np.var(sumby_yaxis[span.top:span.bottom]))
|
| 83 |
+
if subspan_list is not None:
|
| 84 |
+
|
| 85 |
+
_, maxspan = find_span(subspan_list, max)
|
| 86 |
+
_, minspan = find_span(subspan_list, min)
|
| 87 |
+
|
| 88 |
+
sum_height = sum(c.height for c in subspan_list)
|
| 89 |
+
|
| 90 |
+
if maxspan.height / minspan.height > 2.5 or sum_height / span.height < 0.3 or len(subspan_list) == 1:
|
| 91 |
+
subspan_list = None
|
| 92 |
+
if subspan_list is not None and len(subspan_list) > 1:
|
| 93 |
+
span_list += subspan_list
|
| 94 |
+
else:
|
| 95 |
+
span_list.append(span)
|
| 96 |
+
return span_list
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def shrink_span_list(src_img, span_list, shrink_vert_space=True, shrink_hor_space=True):
|
| 101 |
+
height, width = src_img.shape[0], src_img.shape[1]
|
| 102 |
+
|
| 103 |
+
sum_spacing = 0
|
| 104 |
+
if shrink_vert_space:
|
| 105 |
+
for ii in range(len(span_list)-1):
|
| 106 |
+
line_spacing = span_list[ii+1].top - span_list[ii].bottom
|
| 107 |
+
sum_spacing += line_spacing
|
| 108 |
+
line_spacing = int(round(line_spacing / 2))
|
| 109 |
+
span_list[ii+1].top -= line_spacing
|
| 110 |
+
span_list[ii].set_bottom(span_list[ii].bottom + line_spacing)
|
| 111 |
+
|
| 112 |
+
if len(span_list) >= 2:
|
| 113 |
+
mean_spacing = int(0.5 * round(sum_spacing / (len(span_list)-1)))
|
| 114 |
+
span_list[0].top = max(0, span_list[0].top-mean_spacing)
|
| 115 |
+
span_list[0].set_bottom(span_list[0].bottom)
|
| 116 |
+
span_list[-1].set_bottom(min(src_img.shape[0], span_list[-1].bottom))
|
| 117 |
+
|
| 118 |
+
left_var, middle_var = -1, -1
|
| 119 |
+
if shrink_hor_space:
|
| 120 |
+
left_pnts, middle_pnts = [], []
|
| 121 |
+
for ii in range(len(span_list)):
|
| 122 |
+
s = span_list[ii]
|
| 123 |
+
im = src_img[s.top: s.bottom, 0: width]
|
| 124 |
+
sumby_yaxis = np.mean(im, axis=0)
|
| 125 |
+
content_array = np.where(sumby_yaxis > 10)[0].tolist()
|
| 126 |
+
left, right = 0, width
|
| 127 |
+
if len(content_array) != 0:
|
| 128 |
+
left, right = content_array[0], content_array[-1]
|
| 129 |
+
span_list[ii].set_left(left)
|
| 130 |
+
span_list[ii].set_right(right)
|
| 131 |
+
s = span_list[ii]
|
| 132 |
+
left_pnts.append(left)
|
| 133 |
+
middle_pnts.append((left+right)/2)
|
| 134 |
+
left_var, middle_var = np.var(np.array(left_pnts)), np.var(np.array(middle_pnts))
|
| 135 |
+
|
| 136 |
+
return span_list, (left_var, middle_var)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def find_span(span_list, max_or_min=max, key="height"):
|
| 141 |
+
if key=="height":
|
| 142 |
+
return max_or_min(enumerate(span_list), key=(lambda x: span_list[x[0]].height), default = -1)
|
| 143 |
+
else:
|
| 144 |
+
return max_or_min(enumerate(span_list), key=(lambda x: span_list[x[0]].width), default = -1)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def discard_spans(span_list, thresh_ratio=0.3):
|
| 149 |
+
index, max_span = find_span(span_list, max)
|
| 150 |
+
max_height = max_span.height
|
| 151 |
+
height_thresh = max_height * thresh_ratio
|
| 152 |
+
new_spanlist = []
|
| 153 |
+
for sp in span_list:
|
| 154 |
+
if sp.height < height_thresh:
|
| 155 |
+
continue
|
| 156 |
+
new_spanlist.append(sp)
|
| 157 |
+
|
| 158 |
+
return new_spanlist
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def plot_mapresult(sumbyvector, xlength, span_list=None, thresh=None):
|
| 163 |
+
'''for experiment'''
|
| 164 |
+
try:
|
| 165 |
+
import matplotlib.pyplot as plt
|
| 166 |
+
plt.plot(sumbyvector)
|
| 167 |
+
plt.ylabel('div pnt value')
|
| 168 |
+
plt.xlabel('div pnt coord')
|
| 169 |
+
s = [0, 255]
|
| 170 |
+
x_cords = []
|
| 171 |
+
if span_list is not None:
|
| 172 |
+
for sp in span_list:
|
| 173 |
+
x_cords.append(sp.top)
|
| 174 |
+
x_cords.append(sp.bottom)
|
| 175 |
+
if thresh is not None:
|
| 176 |
+
for tr in thresh:
|
| 177 |
+
plt.vlines(x = x_cords, ymin = 0, ymax = max(s),
|
| 178 |
+
colors = 'purple',
|
| 179 |
+
label = 'vline_multiple - full height')
|
| 180 |
+
plt.hlines(y = tr * sumbyvector.mean(), xmin = 0, xmax = xlength, linestyles='--')
|
| 181 |
+
plt.show()
|
| 182 |
+
except:
|
| 183 |
+
pass
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def box(width, height):
|
| 188 |
+
return np.ones((height, width), dtype=np.uint8)
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def crop_img(img, crop_ratio=0.2, clip_width=True, dilate=False):
|
| 192 |
+
h, w = img.shape[:2]
|
| 193 |
+
moments = cv2.moments(img)
|
| 194 |
+
area = moments['m00']
|
| 195 |
+
if area != 0:
|
| 196 |
+
mean_x = int(round(moments['m10'] / area))
|
| 197 |
+
mean_y = int(round(moments['m01'] / area))
|
| 198 |
+
crop_r = int(round(crop_ratio * w))
|
| 199 |
+
if clip_width:
|
| 200 |
+
crop_x0 = np.clip(mean_x - crop_r, 0, w)
|
| 201 |
+
crop_x1 = np.clip(mean_x + crop_r, 0, w)
|
| 202 |
+
if crop_x1 > crop_x0:
|
| 203 |
+
img = img[:, crop_x0: crop_x1]
|
| 204 |
+
else:
|
| 205 |
+
crop_r = np.clip(crop_r * 2, 0, w - 1)
|
| 206 |
+
img = img[:, crop_r:]
|
| 207 |
+
img = np.copy(img)
|
| 208 |
+
if clip_width and dilate:
|
| 209 |
+
w = int(round(w/7))
|
| 210 |
+
if w > 1:
|
| 211 |
+
img = cv2.dilate(img, box(w, 1), 1)
|
| 212 |
+
return img, img.shape[0], img.shape[1]
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def split_textblock(src_img, crop_ratio=0.2, blur=False, show_process=False, discard=True, shrink=True, recheck=False, clip_width=True, dilate=True):
|
| 217 |
+
|
| 218 |
+
if blur:
|
| 219 |
+
src_img = cv2.GaussianBlur(src_img,(3,3),cv2.BORDER_DEFAULT)
|
| 220 |
+
if crop_ratio > 0:
|
| 221 |
+
img, height, width = crop_img(src_img, crop_ratio=crop_ratio, clip_width=clip_width, dilate=dilate)
|
| 222 |
+
else:
|
| 223 |
+
img, height, width = src_img, src_img.shape[0], src_img.shape[1]
|
| 224 |
+
|
| 225 |
+
sumby_yaxis = img.mean(axis=1)
|
| 226 |
+
bound0 = np.where(sumby_yaxis > sumby_yaxis.mean() * 0.1)[0].tolist()
|
| 227 |
+
vars = (-1, -1)
|
| 228 |
+
|
| 229 |
+
if len(bound0) < 2:
|
| 230 |
+
return [TextSpan(0, height-1, 0, width - 1)], vars
|
| 231 |
+
|
| 232 |
+
base_span = TextSpan(bound0[0], bound0[-1])
|
| 233 |
+
meanby_yaxis = sumby_yaxis.mean()
|
| 234 |
+
|
| 235 |
+
thresh_ratio = [0.4, 0.8]
|
| 236 |
+
thresh0 = meanby_yaxis * thresh_ratio[0]
|
| 237 |
+
thresh2 = meanby_yaxis * thresh_ratio[1]
|
| 238 |
+
|
| 239 |
+
span_list = split_step0(base_span, thresh0, sumby_yaxis, thresh2=thresh2)
|
| 240 |
+
if span_list is None:
|
| 241 |
+
return None, None
|
| 242 |
+
if discard:
|
| 243 |
+
span_list = discard_spans(span_list)
|
| 244 |
+
if shrink:
|
| 245 |
+
span_list, vars = shrink_span_list(src_img, span_list)
|
| 246 |
+
|
| 247 |
+
'''for experiment'''
|
| 248 |
+
if show_process:
|
| 249 |
+
plot_mapresult(sumby_yaxis, height, span_list=span_list, thresh=thresh_ratio)
|
| 250 |
+
|
| 251 |
+
if recheck and len(span_list) == 1 and crop_ratio > 0:
|
| 252 |
+
return split_textblock(src_img, crop_ratio==-1, show_process=show_process, discard=discard, shrink=shrink, recheck=False)
|
| 253 |
+
|
| 254 |
+
valid_span_list = []
|
| 255 |
+
for span in span_list:
|
| 256 |
+
if span.top is None:
|
| 257 |
+
span.set_top(0)
|
| 258 |
+
if span.left is None:
|
| 259 |
+
span.set_left(0)
|
| 260 |
+
if span.right is None:
|
| 261 |
+
span.set_right(width)
|
| 262 |
+
if span.bottom is None:
|
| 263 |
+
span.set_bottom(height)
|
| 264 |
+
valid_span_list.append(span)
|
| 265 |
+
|
| 266 |
+
return valid_span_list, vars
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# def tessocr_img2text(img, lang):
|
| 271 |
+
# img = Image.fromarray(img)
|
| 272 |
+
# if re.findall("vert", lang):
|
| 273 |
+
# psm = PSM.SINGLE_BLOCK_VERT_TEXT
|
| 274 |
+
# else:
|
| 275 |
+
# psm = PSM.SINGLE_LINE
|
| 276 |
+
# return tesserocr.image_to_text(img, psm=psm, lang=lang, path=TESSDATA_PATH)
|
| 277 |
+
|
| 278 |
+
# def tessocr_img2text(img, lang):
|
| 279 |
+
# psm = "5" if re.findall("vert", lang) else "7"
|
| 280 |
+
# config = r'--tessdata-dir "models\tessdata" --psm ' + psm
|
| 281 |
+
# return pytesseract.image_to_string(img, lang=lang, config=config)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def textspan2list(span_list):
|
| 285 |
+
converted_list = []
|
| 286 |
+
for ii, s in enumerate(span_list):
|
| 287 |
+
converted_list.append([])
|
| 288 |
+
converted_list[ii].append(s.top)
|
| 289 |
+
converted_list[ii].append(s.left)
|
| 290 |
+
converted_list[ii].append(s.bottom)
|
| 291 |
+
converted_list[ii].append(s.right)
|
| 292 |
+
return converted_list
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def manga_split(img, bbox=None, show_process=False, clip_width=False) -> list[TextSpan]:
|
| 297 |
+
|
| 298 |
+
im = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
|
| 299 |
+
imh, imw = im.shape[:2]
|
| 300 |
+
|
| 301 |
+
if bbox is None:
|
| 302 |
+
bbox = [0, 0, im.shape[1], im.shape[0]]
|
| 303 |
+
bboxes = [bbox]
|
| 304 |
+
|
| 305 |
+
span_list, _ = split_textblock(im, show_process=show_process, shrink=False, recheck=True, discard=False, crop_ratio=0)
|
| 306 |
+
if span_list is None:
|
| 307 |
+
return [TextSpan(0, 0, im.shape[1], im.shape[0])]
|
| 308 |
+
# span_list, _ = shrink_span_list(im, span_list, shrink_vert_space=False)
|
| 309 |
+
|
| 310 |
+
for ii, span in enumerate(span_list):
|
| 311 |
+
left = span.left
|
| 312 |
+
right = span.right
|
| 313 |
+
if ii == 0:
|
| 314 |
+
span.left = 0
|
| 315 |
+
else:
|
| 316 |
+
span.left = span.top
|
| 317 |
+
if ii == len(span_list) - 1:
|
| 318 |
+
span.right = im.shape[0]
|
| 319 |
+
else:
|
| 320 |
+
span.right = span.bottom
|
| 321 |
+
span.top = imw - right
|
| 322 |
+
span.bottom = imw - left
|
| 323 |
+
span.height = span.bottom - span.top
|
| 324 |
+
span.width = span.right - span.left
|
| 325 |
+
|
| 326 |
+
return span_list
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def tessocr_img2text_linemode(img, span_list=None, combine_lines=True, show_process=False, gen_data=False, lang="comic6k", jpn_vert=False):
|
| 330 |
+
if jpn_vert:
|
| 331 |
+
lang = "jpn_vert"
|
| 332 |
+
img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
| 333 |
+
hig = img.shape[0]
|
| 334 |
+
wid = img.shape[1]
|
| 335 |
+
if hig * wid < 5:
|
| 336 |
+
return '', -1, -1
|
| 337 |
+
|
| 338 |
+
bw = 3
|
| 339 |
+
text = ''
|
| 340 |
+
alignment, vars = 0, (-1, -1)
|
| 341 |
+
if span_list is None:
|
| 342 |
+
span_list, vars = split_textblock(img, show_process=show_process)
|
| 343 |
+
_, maxspan = find_span(span_list, max)
|
| 344 |
+
maxh = bw*2 + maxspan.height
|
| 345 |
+
else:
|
| 346 |
+
maxh = max([s[2]-s[0] for s in span_list])
|
| 347 |
+
maxh = bw*2 + maxh
|
| 348 |
+
|
| 349 |
+
long_line = []
|
| 350 |
+
word_space = int(round(maxh / 8))
|
| 351 |
+
img = 255 - img
|
| 352 |
+
for ind, s in enumerate(span_list):
|
| 353 |
+
if isinstance(s, list):
|
| 354 |
+
im = img[s[0]: s[2], s[1]: s[3]]
|
| 355 |
+
else:
|
| 356 |
+
im = img[s.top: s.bottom, s.left: s.right]
|
| 357 |
+
|
| 358 |
+
hw1 = int(round((maxh - im.shape[0])/2))
|
| 359 |
+
hw2 = maxh - hw1 - im.shape[0]
|
| 360 |
+
dst = cv2.copyMakeBorder(im, hw1, hw2, word_space, word_space, cv2.BORDER_CONSTANT, None, value=[255, 255, 255])
|
| 361 |
+
|
| 362 |
+
if not combine_lines:
|
| 363 |
+
text += tessocr_img2text(dst, lang=lang) +'\n'
|
| 364 |
+
else:
|
| 365 |
+
long_line.append(dst)
|
| 366 |
+
if show_process:
|
| 367 |
+
cv2.imshow(str(ind), dst)
|
| 368 |
+
|
| 369 |
+
if combine_lines:
|
| 370 |
+
long_line = cv2.hconcat(long_line)
|
| 371 |
+
if jpn_vert:
|
| 372 |
+
long_line = cv2.rotate(long_line, cv2.ROTATE_90_CLOCKWISE)
|
| 373 |
+
if show_process:
|
| 374 |
+
cv2.namedWindow("long line:", cv2.WINDOW_NORMAL)
|
| 375 |
+
cv2.imshow("long line:", long_line)
|
| 376 |
+
if gen_data:
|
| 377 |
+
return long_line
|
| 378 |
+
res = tessocr_img2text(long_line, lang=lang)
|
| 379 |
+
mean_height = -1
|
| 380 |
+
if len(span_list) != 0:
|
| 381 |
+
if isinstance(span_list[0], list):
|
| 382 |
+
mean_height = np.mean(np.array([s[2]-s[0] for s in span_list]))
|
| 383 |
+
else:
|
| 384 |
+
mean_height = np.mean(np.array([s.height for s in span_list]))
|
| 385 |
+
alignment = 1 if vars[1] < vars[0] else 0
|
| 386 |
+
return res, mean_height, alignment
|
utils/stroke_width_calculator.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2, os, time
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def calculate_derivatives(gx, gy):
|
| 6 |
+
mag = np.sqrt(gx*gx + gy*gy)
|
| 7 |
+
if mag==0:
|
| 8 |
+
return False, -1, -1
|
| 9 |
+
else:
|
| 10 |
+
return True, gx / mag, gy / mag
|
| 11 |
+
|
| 12 |
+
def sw_calculator(mask, canny_img, gradient_x, gradient_y, show_process=False):
|
| 13 |
+
height, width = canny_img.shape[0], canny_img.shape[1]
|
| 14 |
+
|
| 15 |
+
if show_process:
|
| 16 |
+
drawborder = np.zeros((canny_img.shape[0], canny_img.shape[1], 3), dtype=np.uint8)
|
| 17 |
+
|
| 18 |
+
pnts = np.where(np.logical_and(canny_img != 0, mask!=0))
|
| 19 |
+
total_pnt_num = pnts[0].shape[0]
|
| 20 |
+
sample_pnt_num = 150
|
| 21 |
+
sample_step = total_pnt_num / sample_pnt_num if total_pnt_num > sample_pnt_num else 1
|
| 22 |
+
|
| 23 |
+
cur_pnt_ind = 0
|
| 24 |
+
ray_list = []
|
| 25 |
+
|
| 26 |
+
while cur_pnt_ind < total_pnt_num:
|
| 27 |
+
start_x, start_y = pnts[1][cur_pnt_ind], pnts[0][cur_pnt_ind]
|
| 28 |
+
ray_arr = [start_x, start_y, -1, -1, -1]
|
| 29 |
+
valid, dx, dy = calculate_derivatives(gradient_x[start_y][start_x], gradient_y[start_y][start_x])
|
| 30 |
+
|
| 31 |
+
if valid:
|
| 32 |
+
inc = 0.2
|
| 33 |
+
cur_x, cur_y = start_x + inc * dx, start_y + inc * dy
|
| 34 |
+
while (True):
|
| 35 |
+
tmp_curx, tmp_cury = int(cur_x), int(cur_y)
|
| 36 |
+
if tmp_curx < 0 or tmp_curx >= width or tmp_cury <= 0 or tmp_cury >= height:
|
| 37 |
+
break
|
| 38 |
+
if canny_img[tmp_cury][tmp_curx] == 0:
|
| 39 |
+
valid, dx_t, dy_t = calculate_derivatives(gradient_x[tmp_cury][tmp_curx], gradient_y[tmp_cury][tmp_curx])
|
| 40 |
+
if not valid:
|
| 41 |
+
break
|
| 42 |
+
if np.arccos(-dx * dx_t + -dy * dy_t) < np.pi / 2.0:
|
| 43 |
+
ray_arr[2] = tmp_curx
|
| 44 |
+
ray_arr[3] = tmp_cury
|
| 45 |
+
ray_arr[4] = np.sqrt((start_x - tmp_curx)**2 + (start_y - tmp_cury)**2)
|
| 46 |
+
break
|
| 47 |
+
cur_x += dx
|
| 48 |
+
cur_y += dy
|
| 49 |
+
if ray_arr[2] != -1:
|
| 50 |
+
ray_list.append(ray_arr)
|
| 51 |
+
if show_process:
|
| 52 |
+
drawborder = cv2.arrowedLine(drawborder, (ray_arr[0], ray_arr[1]), (ray_arr[2], ray_arr[3]),
|
| 53 |
+
(0, 255, 0), 1)
|
| 54 |
+
|
| 55 |
+
cur_pnt_ind += sample_step
|
| 56 |
+
cur_pnt_ind = int(round(cur_pnt_ind))
|
| 57 |
+
if show_process and len(ray_list) != 0:
|
| 58 |
+
ray_list.sort(key=lambda x: x[4])
|
| 59 |
+
cv2.imshow("border", drawborder)
|
| 60 |
+
cv2.imshow("cannyimg", canny_img)
|
| 61 |
+
cv2.waitKey(0)
|
| 62 |
+
return ray_list
|
| 63 |
+
|
| 64 |
+
def strokewidth_check(text_mask, labels, num_labels, stats, debug_type=0):
|
| 65 |
+
rays_width = []
|
| 66 |
+
height, width = text_mask.shape[0], text_mask.shape[1]
|
| 67 |
+
|
| 68 |
+
blur_img = cv2.dilate(text_mask ,(3,3),cv2.BORDER_DEFAULT)
|
| 69 |
+
|
| 70 |
+
# canny_img = cv2.Canny(cv2.dilate(text_mask, (3,3), 1), 170, 320, L2gradient=True, apertureSize=3)
|
| 71 |
+
|
| 72 |
+
_, canny_img = cv2.threshold(text_mask, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
|
| 73 |
+
blur2 = blur_img.astype(float) / 255
|
| 74 |
+
gradient_x = cv2.Scharr(blur2, ddepth=-1, dx=1, dy=0)
|
| 75 |
+
gradient_x = cv2.GaussianBlur(gradient_x ,(3, 3),cv2.BORDER_DEFAULT)
|
| 76 |
+
gradient_y = cv2.Scharr(blur2, ddepth=-1, dx=0, dy=1)
|
| 77 |
+
gradient_y = cv2.GaussianBlur(gradient_y ,(3, 3),cv2.BORDER_DEFAULT)
|
| 78 |
+
|
| 79 |
+
img_area = text_mask.shape[0] * text_mask.shape[1]
|
| 80 |
+
show_process = True if debug_type > 0 else False
|
| 81 |
+
for lab in range(num_labels):
|
| 82 |
+
stat = stats[lab]
|
| 83 |
+
if lab != 0 and stat[4] > img_area * 0.002:
|
| 84 |
+
x1, y1, x2, y2 = stat[0] - 2, stat[1] - 2, stat[0] + stat[2] + 2, stat[1] + stat[3] + 2
|
| 85 |
+
x1, x2 = max(x1, 0), min(x2, width)
|
| 86 |
+
y1, y2 = max(y1, 0), min(y2, height)
|
| 87 |
+
labcord = np.where(labels==lab)
|
| 88 |
+
labcord2 = (labcord[0] - y1, labcord[1] - x1)
|
| 89 |
+
text_roi = np.zeros((y2-y1, x2-x1), dtype=np.uint8)
|
| 90 |
+
text_roi[labcord2] = 255
|
| 91 |
+
text_roi = cv2.GaussianBlur(text_roi ,(3,3), cv2.BORDER_DEFAULT)
|
| 92 |
+
ray_list = sw_calculator(text_roi,
|
| 93 |
+
canny_img[y1: y2, x1: x2],
|
| 94 |
+
gradient_x[y1: y2, x1: x2],
|
| 95 |
+
gradient_y[y1: y2, x1: x2],
|
| 96 |
+
show_process=show_process)
|
| 97 |
+
if len(ray_list) != 0:
|
| 98 |
+
ray_list.sort(key=lambda x: x[4])
|
| 99 |
+
rays_width.append([int(lab), ray_list[int(len(ray_list)/2)][4]])
|
| 100 |
+
|
| 101 |
+
if len(rays_width) != 0:
|
| 102 |
+
rays_width = np.array(rays_width)
|
| 103 |
+
mean_width = np.mean(rays_width[:, 1])
|
| 104 |
+
ma = np.int0(rays_width[:, 0])
|
| 105 |
+
mean_area = np.mean(stats[ma][:, 4])
|
| 106 |
+
|
| 107 |
+
false_labels = np.where(rays_width[:, 1] > 2*mean_width)[0]
|
| 108 |
+
false_labels = rays_width[false_labels, 0].astype(np.int32)
|
| 109 |
+
for fl in false_labels:
|
| 110 |
+
if stats[fl][4] > 2 * mean_area:
|
| 111 |
+
text_mask[np.where(labels==fl)] = 0
|
| 112 |
+
return text_mask
|
| 113 |
+
|
utils/structures.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Tuple, List, ClassVar, Union, Any, Dict, Set
|
| 2 |
+
from dataclasses import dataclass, field, is_dataclass
|
| 3 |
+
import copy
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# decorator to wrap original __init__
|
| 10 |
+
# https://www.geeksforgeeks.org/creating-nested-dataclass-objects-in-python/
|
| 11 |
+
def nested_dataclass(*args, **dataclass_kwargs):
|
| 12 |
+
'''
|
| 13 |
+
nested dataclass support \n
|
| 14 |
+
also ignore extra arguments
|
| 15 |
+
'''
|
| 16 |
+
def wrapper(check_class):
|
| 17 |
+
|
| 18 |
+
# passing class to investigate
|
| 19 |
+
check_class = dataclass(check_class, **dataclass_kwargs)
|
| 20 |
+
o_init = check_class.__init__
|
| 21 |
+
|
| 22 |
+
def __init__(self, *args, **kwargs):
|
| 23 |
+
|
| 24 |
+
store_deprecated = 'deprecated_attributes' in self.__annotations__
|
| 25 |
+
deprecated = {}
|
| 26 |
+
for name in list(kwargs.keys()):
|
| 27 |
+
if name not in self.__annotations__:
|
| 28 |
+
# print(f'warning: type object \'{self.__class__.__name__}\' has no attribute {name}, might be loading from an older config')
|
| 29 |
+
val = kwargs.pop(name)
|
| 30 |
+
if store_deprecated:
|
| 31 |
+
deprecated[name] = val
|
| 32 |
+
continue
|
| 33 |
+
value = kwargs[name]
|
| 34 |
+
# getting field type
|
| 35 |
+
ft = check_class.__annotations__.get(name, None)
|
| 36 |
+
|
| 37 |
+
if is_dataclass(ft) and isinstance(value, dict):
|
| 38 |
+
obj = ft(**value)
|
| 39 |
+
kwargs[name]= obj
|
| 40 |
+
|
| 41 |
+
if len(deprecated) > 0:
|
| 42 |
+
kwargs['deprecated_attributes'] = deprecated
|
| 43 |
+
|
| 44 |
+
o_init(self, *args, **kwargs)
|
| 45 |
+
check_class.__init__=__init__
|
| 46 |
+
|
| 47 |
+
return check_class
|
| 48 |
+
|
| 49 |
+
return wrapper(args[0]) if args else wrapper
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@dataclass
|
| 53 |
+
class Config:
|
| 54 |
+
|
| 55 |
+
def update(self, key: str, value):
|
| 56 |
+
assert key in self.__annotations__, f'type object \'{self.__class__.__name__}\' has no attribute {key}'
|
| 57 |
+
self.__setattr__(key, value)
|
| 58 |
+
|
| 59 |
+
@classmethod
|
| 60 |
+
def annotations_set(cls):
|
| 61 |
+
return set(list(cls.__annotations__))
|
| 62 |
+
|
| 63 |
+
def __getitem__(self, key: str):
|
| 64 |
+
assert key in self.__annotations__, f'type object \'{self.__class__.__name__}\' has no attribute {key}'
|
| 65 |
+
return self.__getattribute__(key)
|
| 66 |
+
|
| 67 |
+
def __setitem__(self, key: str, value):
|
| 68 |
+
self.__setattr__(key, value)
|
| 69 |
+
|
| 70 |
+
@classmethod
|
| 71 |
+
def params(cls):
|
| 72 |
+
return cls.__annotations__
|
| 73 |
+
|
| 74 |
+
def merge(self, target):
|
| 75 |
+
tgt_keys = target.annotations_set()
|
| 76 |
+
for key in tgt_keys:
|
| 77 |
+
self.update(key, target[key])
|
| 78 |
+
|
| 79 |
+
def copy(self):
|
| 80 |
+
return copy.deepcopy(self)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
MODULE_PATH = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
| 84 |
+
BASE_PATH = os.path.dirname(MODULE_PATH)
|
utils/text_layout.py
ADDED
|
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from .imgproc_utils import rotate_image
|
| 5 |
+
from .textblock import TextBlock, TextAlignment
|
| 6 |
+
|
| 7 |
+
class Line:
|
| 8 |
+
|
| 9 |
+
def __init__(self, text: str = '', pos_x: int = 0, pos_y: int = 0, length: float = 0, spacing: int = 0) -> None:
|
| 10 |
+
self.text = text
|
| 11 |
+
self.pos_x = pos_x
|
| 12 |
+
self.pos_y = pos_y
|
| 13 |
+
self.length = int(length)
|
| 14 |
+
self.num_words = 0
|
| 15 |
+
if text:
|
| 16 |
+
self.num_words += 1
|
| 17 |
+
self.spacing = 0
|
| 18 |
+
self.add_spacing(spacing)
|
| 19 |
+
|
| 20 |
+
def append_right(self, word: str, w_len: int, delimiter: str = ''):
|
| 21 |
+
self.text = self.text + delimiter + word
|
| 22 |
+
if word:
|
| 23 |
+
self.num_words += 1
|
| 24 |
+
self.length += w_len
|
| 25 |
+
|
| 26 |
+
def append_left(self, word: str, w_len: int, delimiter: str = ''):
|
| 27 |
+
self.text = word + delimiter + self.text
|
| 28 |
+
if word:
|
| 29 |
+
self.num_words += 1
|
| 30 |
+
self.length += w_len
|
| 31 |
+
|
| 32 |
+
def add_spacing(self, spacing: int):
|
| 33 |
+
self.spacing = spacing
|
| 34 |
+
self.pos_x -= spacing
|
| 35 |
+
self.length += 2 * spacing
|
| 36 |
+
|
| 37 |
+
def strip_spacing(self):
|
| 38 |
+
self.length -= self.spacing * 2
|
| 39 |
+
self.pos_x += self.spacing
|
| 40 |
+
self.spacing = 0
|
| 41 |
+
|
| 42 |
+
def line_is_valid(line: Line, new_len: int, delimiter_len, max_width, words_length, srcline_wlist, line_no: int, line_height, ref_src_lines: bool = False):
|
| 43 |
+
if ref_src_lines:
|
| 44 |
+
# if line_no >= 0 and line_no < len(srcline_wlist):
|
| 45 |
+
# _max_width = min(srcline_wlist[line_no], max_width)
|
| 46 |
+
# else:
|
| 47 |
+
# _max_width = max_width
|
| 48 |
+
if line_no >= 0 and line_no < len(srcline_wlist):
|
| 49 |
+
_max_width = srcline_wlist[line_no] * words_length
|
| 50 |
+
else:
|
| 51 |
+
_max_width = np.inf
|
| 52 |
+
_max_width = max(srcline_wlist) * words_length
|
| 53 |
+
_max_width = _max_width + delimiter_len * line.num_words
|
| 54 |
+
max_width = min(max_width, _max_width)
|
| 55 |
+
|
| 56 |
+
if new_len < max_width:
|
| 57 |
+
return True
|
| 58 |
+
else:
|
| 59 |
+
if line.length / max_width < max_width / new_len:
|
| 60 |
+
return True
|
| 61 |
+
else:
|
| 62 |
+
return False
|
| 63 |
+
|
| 64 |
+
def layout_lines_aligncenter(
|
| 65 |
+
blk: TextBlock,
|
| 66 |
+
mask: np.ndarray,
|
| 67 |
+
words: List[str],
|
| 68 |
+
centroid: List[int],
|
| 69 |
+
wl_list: List[int],
|
| 70 |
+
delimiter_len: int,
|
| 71 |
+
line_height: int,
|
| 72 |
+
spacing: int = 0,
|
| 73 |
+
delimiter: str = ' ',
|
| 74 |
+
max_central_width: float = np.inf,
|
| 75 |
+
word_break: bool = False,
|
| 76 |
+
ref_src_lines = False,
|
| 77 |
+
srcline_wlist=None,
|
| 78 |
+
start_from_top=False
|
| 79 |
+
)->List[Line]:
|
| 80 |
+
|
| 81 |
+
lh_pad = 0
|
| 82 |
+
if blk.line_spacing > 1:
|
| 83 |
+
lh_pad = int(np.ceil(line_height - line_height / blk.line_spacing))
|
| 84 |
+
|
| 85 |
+
centroid_x, centroid_y = centroid
|
| 86 |
+
adjust_x = adjust_y = 0
|
| 87 |
+
|
| 88 |
+
border_thr = 220
|
| 89 |
+
|
| 90 |
+
# layout the central line, the center word is approximately aligned with the centroid of the mask
|
| 91 |
+
num_words = len(words)
|
| 92 |
+
len_left, len_right = [], []
|
| 93 |
+
wlst_left, wlst_right = [], []
|
| 94 |
+
sum_left, sum_right = 0, 0
|
| 95 |
+
words_length = sum(wl_list)
|
| 96 |
+
if num_words > 1:
|
| 97 |
+
wl_array = np.array(wl_list, dtype=np.float64)
|
| 98 |
+
wl_cumsums = np.cumsum(wl_array)
|
| 99 |
+
wl_cumsums = wl_cumsums - wl_cumsums[-1] / 2 - wl_array / 2
|
| 100 |
+
central_index = np.argmin(np.abs(wl_cumsums))
|
| 101 |
+
|
| 102 |
+
if central_index > 0:
|
| 103 |
+
wlst_left = words[:central_index]
|
| 104 |
+
len_left = wl_list[:central_index]
|
| 105 |
+
sum_left = np.sum(len_left)
|
| 106 |
+
if central_index < num_words - 1:
|
| 107 |
+
wlst_right = words[central_index + 1:]
|
| 108 |
+
len_right = wl_list[central_index + 1:]
|
| 109 |
+
sum_right = np.sum(len_right)
|
| 110 |
+
else:
|
| 111 |
+
central_index = 0
|
| 112 |
+
|
| 113 |
+
pos_y = centroid_y - line_height // 2
|
| 114 |
+
pos_x = centroid_x - wl_list[central_index] // 2
|
| 115 |
+
|
| 116 |
+
bh, bw = mask.shape[:2]
|
| 117 |
+
central_line = Line(words[central_index], pos_x, pos_y, wl_list[central_index], spacing)
|
| 118 |
+
line_bottom = pos_y + line_height
|
| 119 |
+
while (sum_left > 0 or sum_right > 0) and not start_from_top:
|
| 120 |
+
left_valid, right_valid = False, False
|
| 121 |
+
|
| 122 |
+
if sum_left > 0:
|
| 123 |
+
new_len_l = central_line.length + len_left[-1] + delimiter_len
|
| 124 |
+
new_x_l = centroid_x - new_len_l // 2
|
| 125 |
+
new_r_l = new_x_l + new_len_l
|
| 126 |
+
if (new_x_l > 0 and new_r_l < bw):
|
| 127 |
+
if mask[pos_y: line_bottom - lh_pad, new_x_l].mean() > border_thr and \
|
| 128 |
+
mask[pos_y: line_bottom - lh_pad, new_r_l].mean() > border_thr:
|
| 129 |
+
left_valid = True
|
| 130 |
+
if sum_right > 0:
|
| 131 |
+
new_len_r = central_line.length + len_right[0] + delimiter_len
|
| 132 |
+
new_x_r = centroid_x - new_len_r // 2 - line_height // 2
|
| 133 |
+
new_r_r = centroid_x + new_len_r // 2 + line_height // 2
|
| 134 |
+
if (new_x_r > 0 and new_r_r < bw):
|
| 135 |
+
if mask[pos_y: line_bottom - lh_pad, new_x_r].mean() > border_thr and \
|
| 136 |
+
mask[pos_y: line_bottom - lh_pad, new_r_r].mean() > border_thr:
|
| 137 |
+
right_valid = True
|
| 138 |
+
|
| 139 |
+
insert_left = False
|
| 140 |
+
if left_valid and right_valid:
|
| 141 |
+
if sum_left > sum_right:
|
| 142 |
+
insert_left = True
|
| 143 |
+
elif left_valid:
|
| 144 |
+
insert_left = True
|
| 145 |
+
elif not right_valid:
|
| 146 |
+
break
|
| 147 |
+
|
| 148 |
+
if insert_left:
|
| 149 |
+
new_len = central_line.length + len_left[-1] + delimiter_len
|
| 150 |
+
else:
|
| 151 |
+
new_len = central_line.length + len_right[0] + delimiter_len
|
| 152 |
+
|
| 153 |
+
line_valid = line_is_valid(central_line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, -1, line_height, ref_src_lines)
|
| 154 |
+
if ref_src_lines and not line_valid and len(srcline_wlist) == 1:
|
| 155 |
+
if new_len < max_central_width:
|
| 156 |
+
line_valid = True
|
| 157 |
+
if not line_valid:
|
| 158 |
+
break
|
| 159 |
+
|
| 160 |
+
if insert_left:
|
| 161 |
+
central_line.append_left(wlst_left.pop(-1), len_left[-1] + delimiter_len, delimiter)
|
| 162 |
+
sum_left -= len_left.pop(-1)
|
| 163 |
+
central_line.pos_x = new_x_l
|
| 164 |
+
else:
|
| 165 |
+
central_line.append_right(wlst_right.pop(0), len_right[0] + delimiter_len, delimiter)
|
| 166 |
+
sum_right -= len_right.pop(0)
|
| 167 |
+
central_line.pos_x = new_x_r
|
| 168 |
+
|
| 169 |
+
line_right_no = line_left_no = 0
|
| 170 |
+
if ref_src_lines:
|
| 171 |
+
nl = len(srcline_wlist)
|
| 172 |
+
if nl % 2 == 0:
|
| 173 |
+
line_right_no = nl // 2
|
| 174 |
+
line_left_no = nl // 2 - 1
|
| 175 |
+
else:
|
| 176 |
+
line_right_no = nl // 2 + 1
|
| 177 |
+
line_left_no = nl // 2 - 1
|
| 178 |
+
|
| 179 |
+
if not start_from_top:
|
| 180 |
+
central_line.strip_spacing()
|
| 181 |
+
lines = [central_line]
|
| 182 |
+
else:
|
| 183 |
+
lines = []
|
| 184 |
+
sum_right = sum(wl_list)
|
| 185 |
+
sum_left = 0
|
| 186 |
+
wlst_right = words
|
| 187 |
+
len_right = wl_list
|
| 188 |
+
line_right_no = 0
|
| 189 |
+
|
| 190 |
+
# layout bottom half
|
| 191 |
+
if sum_right > 0:
|
| 192 |
+
w, wl = wlst_right.pop(0), len_right.pop(0)
|
| 193 |
+
pos_x = centroid_x - wl // 2
|
| 194 |
+
if start_from_top:
|
| 195 |
+
pos_y = centroid_y - int(blk.bounding_rect()[3] / 2)
|
| 196 |
+
else:
|
| 197 |
+
pos_y = centroid_y + line_height // 2
|
| 198 |
+
pos_y = max(0, min(pos_y, mask.shape[0] - 1))
|
| 199 |
+
top_mean = mask[pos_y, :].mean()
|
| 200 |
+
x_mean = mask.mean(axis=1)
|
| 201 |
+
base_mean = x_mean.max() / 2
|
| 202 |
+
if top_mean < base_mean:
|
| 203 |
+
available_y = np.where(
|
| 204 |
+
x_mean[pos_y:] > base_mean
|
| 205 |
+
)[0]
|
| 206 |
+
if len(available_y) > 0:
|
| 207 |
+
adjust_y = min(available_y[0], line_height)
|
| 208 |
+
pos_y = pos_y + adjust_y
|
| 209 |
+
line_bottom = pos_y + line_height
|
| 210 |
+
line = Line(w, pos_x, pos_y, wl, spacing)
|
| 211 |
+
lines.append(line)
|
| 212 |
+
sum_right -= wl
|
| 213 |
+
while sum_right > 0:
|
| 214 |
+
w, wl = wlst_right.pop(0), len_right.pop(0)
|
| 215 |
+
sum_right -= wl
|
| 216 |
+
new_len = line.length + wl + delimiter_len
|
| 217 |
+
new_x = centroid_x - new_len // 2 - line_height // 2
|
| 218 |
+
right_x = new_x + new_len + line_height // 2
|
| 219 |
+
if new_x < 0 or right_x >= bw:
|
| 220 |
+
line_valid = False
|
| 221 |
+
elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
|
| 222 |
+
mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
|
| 223 |
+
line_valid = False
|
| 224 |
+
if ref_src_lines and (len(wl_list) == 1 or line_right_no + 1 >= len(srcline_wlist)) and \
|
| 225 |
+
line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines):
|
| 226 |
+
line_valid = True
|
| 227 |
+
else:
|
| 228 |
+
line_valid = True
|
| 229 |
+
if line_valid:
|
| 230 |
+
line.append_right(w, wl+delimiter_len, delimiter)
|
| 231 |
+
line.pos_x = new_x
|
| 232 |
+
line_valid = line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_right_no, line_height, ref_src_lines)
|
| 233 |
+
if not line_valid:
|
| 234 |
+
if sum_right > 0:
|
| 235 |
+
w, wl = wlst_right.pop(0), len_right.pop(0)
|
| 236 |
+
sum_right -= wl
|
| 237 |
+
else:
|
| 238 |
+
line.strip_spacing()
|
| 239 |
+
break
|
| 240 |
+
|
| 241 |
+
if not line_valid:
|
| 242 |
+
pos_x = centroid_x - wl // 2
|
| 243 |
+
pos_y = line_bottom
|
| 244 |
+
line_bottom += line_height
|
| 245 |
+
line.strip_spacing()
|
| 246 |
+
line = Line(w, pos_x, pos_y, wl, spacing)
|
| 247 |
+
lines.append(line)
|
| 248 |
+
line_right_no += 1
|
| 249 |
+
|
| 250 |
+
# layout top half
|
| 251 |
+
if sum_left > 0:
|
| 252 |
+
w, wl = wlst_left.pop(-1), len_left.pop(-1)
|
| 253 |
+
pos_x = centroid_x - wl // 2
|
| 254 |
+
pos_y = centroid_y - line_height // 2 - line_height
|
| 255 |
+
pos_y = max(0, min(pos_y, mask.shape[0] - 1))
|
| 256 |
+
line_bottom = pos_y + line_height
|
| 257 |
+
line = Line(w, pos_x, pos_y, wl, spacing)
|
| 258 |
+
lines.insert(0, line)
|
| 259 |
+
sum_left -= wl
|
| 260 |
+
while sum_left > 0:
|
| 261 |
+
w, wl = wlst_left.pop(-1), len_left.pop(-1)
|
| 262 |
+
sum_left -= wl
|
| 263 |
+
new_len = line.length + wl + delimiter_len
|
| 264 |
+
new_x = centroid_x - new_len // 2 - line_height // 2
|
| 265 |
+
right_x = new_x + new_len + line_height // 2
|
| 266 |
+
if new_x <= 0 or right_x >= bw:
|
| 267 |
+
line_valid = False
|
| 268 |
+
elif mask[pos_y: line_bottom - lh_pad, new_x].mean() < border_thr or\
|
| 269 |
+
mask[pos_y: line_bottom - lh_pad, right_x].mean() < border_thr:
|
| 270 |
+
line_valid = False
|
| 271 |
+
if ref_src_lines and line_left_no - 1 < 0 and \
|
| 272 |
+
line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines):
|
| 273 |
+
line_valid = True
|
| 274 |
+
else:
|
| 275 |
+
line_valid = True
|
| 276 |
+
if line_valid:
|
| 277 |
+
line.append_left(w, wl+delimiter_len, delimiter)
|
| 278 |
+
line.pos_x = new_x
|
| 279 |
+
line_valid = line_is_valid(line, new_len, delimiter_len, max_central_width, words_length, srcline_wlist, line_left_no, line_height, ref_src_lines)
|
| 280 |
+
if not line_valid:
|
| 281 |
+
if sum_left > 0:
|
| 282 |
+
w, wl = wlst_left.pop(-1), len_left.pop(-1)
|
| 283 |
+
sum_left -= wl
|
| 284 |
+
else:
|
| 285 |
+
line.strip_spacing()
|
| 286 |
+
break
|
| 287 |
+
|
| 288 |
+
if not line_valid :
|
| 289 |
+
pos_x = centroid_x - wl // 2
|
| 290 |
+
pos_y -= line_height
|
| 291 |
+
line_bottom = pos_y + line_height
|
| 292 |
+
line.strip_spacing()
|
| 293 |
+
line = Line(w, pos_x, pos_y, wl, spacing)
|
| 294 |
+
lines.insert(0, line)
|
| 295 |
+
line_left_no -= 1
|
| 296 |
+
|
| 297 |
+
return lines, (adjust_x, adjust_y)
|
| 298 |
+
|
| 299 |
+
def layout_lines_alignside(
|
| 300 |
+
blk: TextBlock,
|
| 301 |
+
mask: np.ndarray,
|
| 302 |
+
words: List[str],
|
| 303 |
+
origin: List[int],
|
| 304 |
+
wl_list: List[int],
|
| 305 |
+
delimiter_len: int,
|
| 306 |
+
line_height: int,
|
| 307 |
+
spacing: int = 0,
|
| 308 |
+
delimiter: str = ' ',
|
| 309 |
+
word_break: bool = False,
|
| 310 |
+
max_width: int = np.inf,
|
| 311 |
+
ref_src_lines = False,
|
| 312 |
+
srcline_wlist=None,
|
| 313 |
+
)->List[Line]:
|
| 314 |
+
|
| 315 |
+
align_right = blk.fontformat.alignment == TextAlignment.Right
|
| 316 |
+
|
| 317 |
+
ox, oy = origin
|
| 318 |
+
bh, bw = mask.shape[:2]
|
| 319 |
+
num_words = len(words)
|
| 320 |
+
blk_rect = blk.bounding_rect()
|
| 321 |
+
blk_width = blk_rect[2]
|
| 322 |
+
lines = []
|
| 323 |
+
words_length = sum(wl_list)
|
| 324 |
+
|
| 325 |
+
lh_pad = 0
|
| 326 |
+
if blk.line_spacing > 1:
|
| 327 |
+
lh_pad = int(np.ceil(line_height - line_height / blk.line_spacing))
|
| 328 |
+
|
| 329 |
+
if num_words > 0:
|
| 330 |
+
sum_right = np.array(wl_list).sum()
|
| 331 |
+
w, wl = words.pop(0), wl_list.pop(0)
|
| 332 |
+
line = Line(w, ox, oy, wl)
|
| 333 |
+
lines.append(line)
|
| 334 |
+
sum_right -= wl
|
| 335 |
+
line_bottom = oy + line_height
|
| 336 |
+
pos_y = oy
|
| 337 |
+
line_id = 0
|
| 338 |
+
while sum_right > 0:
|
| 339 |
+
w, wl = words.pop(0), wl_list.pop(0)
|
| 340 |
+
sum_right -= wl
|
| 341 |
+
new_len = line.length + wl + delimiter_len
|
| 342 |
+
if align_right:
|
| 343 |
+
new_x = ox + blk_width - new_len - line_height // 2
|
| 344 |
+
else:
|
| 345 |
+
new_x = ox + new_len + line_height // 2
|
| 346 |
+
line_valid = False
|
| 347 |
+
if new_x < bw and new_x > 0:
|
| 348 |
+
if mask[np.clip(pos_y, 0, bh - 1): np.clip(line_bottom - lh_pad, 0, bh), new_x].mean() > 240:
|
| 349 |
+
line_valid = True
|
| 350 |
+
else:
|
| 351 |
+
if ref_src_lines and line_id + 1 >= len(srcline_wlist) and line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines):
|
| 352 |
+
line_valid = True
|
| 353 |
+
if line_valid:
|
| 354 |
+
line_valid = line_is_valid(line, new_len, delimiter_len, max_width, words_length, srcline_wlist, line_id, line_height, ref_src_lines)
|
| 355 |
+
if line_valid:
|
| 356 |
+
line.append_right(w, wl+delimiter_len, delimiter)
|
| 357 |
+
else:
|
| 358 |
+
pos_y = line_bottom
|
| 359 |
+
line_bottom += line_height
|
| 360 |
+
line = Line(w, ox, pos_y, wl)
|
| 361 |
+
line_id += 1
|
| 362 |
+
lines.append(line)
|
| 363 |
+
return lines, (0, 0)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def layout_text(
|
| 368 |
+
blk: TextBlock,
|
| 369 |
+
mask: np.ndarray,
|
| 370 |
+
mask_xyxy: List,
|
| 371 |
+
centroid: List,
|
| 372 |
+
words: List[str],
|
| 373 |
+
wl_list: List[int],
|
| 374 |
+
delimiter: str,
|
| 375 |
+
delimiter_len: int,
|
| 376 |
+
line_height: int,
|
| 377 |
+
spacing: int = 0,
|
| 378 |
+
max_central_width=np.inf,
|
| 379 |
+
src_is_cjk=False,
|
| 380 |
+
tgt_is_cjk=False,
|
| 381 |
+
ref_src_lines = False
|
| 382 |
+
) -> Tuple[str, List]:
|
| 383 |
+
|
| 384 |
+
angle = blk.angle
|
| 385 |
+
alignment = blk.alignment
|
| 386 |
+
|
| 387 |
+
start_from_top = False
|
| 388 |
+
srcline_wlist = None
|
| 389 |
+
|
| 390 |
+
if ref_src_lines:
|
| 391 |
+
srcline_wlist, srcline_width = blk.normalizd_width_list(normalize=False)
|
| 392 |
+
# tgtline_width = sum(wl_list) + delimiter_len * max(len(wl_list) - 1, 0)
|
| 393 |
+
# if tgtline_width < srcline_width:
|
| 394 |
+
# min_bbox = blk.min_rect(rotate_back=True)[0]
|
| 395 |
+
# x1, y1 = min_bbox[0]
|
| 396 |
+
# x2, y2 = min_bbox[2]
|
| 397 |
+
# w = x2 - x1
|
| 398 |
+
# max_central_width = min(max_central_width, w)
|
| 399 |
+
# pass
|
| 400 |
+
|
| 401 |
+
if alignment == TextAlignment.Center and \
|
| 402 |
+
len(srcline_wlist) > 1:
|
| 403 |
+
if len(srcline_wlist) == 2:
|
| 404 |
+
start_from_top = True
|
| 405 |
+
else:
|
| 406 |
+
nw = len(srcline_wlist)
|
| 407 |
+
# nl = min(nw // 2, 2)
|
| 408 |
+
nl = 1
|
| 409 |
+
sum_top = sum(srcline_wlist[:nl])
|
| 410 |
+
sum_btn = sum(srcline_wlist[-nl:])
|
| 411 |
+
start_from_top = sum_top / sum_btn > 1.2 and srcline_wlist[0] / max(srcline_wlist) > 0.9
|
| 412 |
+
|
| 413 |
+
srcline_wlist = np.array(srcline_wlist) / srcline_width
|
| 414 |
+
srcline_wlist = srcline_wlist.tolist()
|
| 415 |
+
# line_height = min((blk.detected_font_size), line_height)
|
| 416 |
+
|
| 417 |
+
# if ref_src_lines:
|
| 418 |
+
# mask = np.ones_like(mask) * 255
|
| 419 |
+
|
| 420 |
+
if max_central_width == np.inf:
|
| 421 |
+
max_central_width = mask.shape[1]
|
| 422 |
+
|
| 423 |
+
centroid_x, centroid_y = centroid
|
| 424 |
+
center_x = mask_xyxy[0] + centroid_x
|
| 425 |
+
center_y = mask_xyxy[1] + centroid_y
|
| 426 |
+
shifted_x, shifted_y = 0, 0
|
| 427 |
+
if abs(angle) > 0:
|
| 428 |
+
|
| 429 |
+
old_h, old_w = mask.shape[:2]
|
| 430 |
+
old_origin = (old_w // 2, old_h // 2)
|
| 431 |
+
rel_cx, rel_cy = centroid[0] - old_origin[0], centroid[1] - old_origin[1]
|
| 432 |
+
|
| 433 |
+
mask = rotate_image(mask, angle)
|
| 434 |
+
rad = np.deg2rad(angle)
|
| 435 |
+
r_sin, r_cos = np.sin(rad), np.cos(rad)
|
| 436 |
+
new_rel_cy = -rel_cx * r_sin + rel_cy * r_cos
|
| 437 |
+
new_rel_cx = rel_cy * r_sin + rel_cx * r_cos
|
| 438 |
+
|
| 439 |
+
shifted_x, shifted_y = new_rel_cx - rel_cx, new_rel_cy - rel_cy
|
| 440 |
+
|
| 441 |
+
new_h, new_w = mask.shape[:2]
|
| 442 |
+
new_origin = (new_w // 2, new_h // 2)
|
| 443 |
+
new_cx, new_cy = new_origin[0] + new_rel_cx, new_origin[1] + new_rel_cy
|
| 444 |
+
centroid = [int(new_cx), int(new_cy)]
|
| 445 |
+
|
| 446 |
+
if alignment == TextAlignment.Center:
|
| 447 |
+
lines, adjust_xy = layout_lines_aligncenter(blk, mask, words, centroid, wl_list, delimiter_len, line_height, spacing, delimiter,
|
| 448 |
+
max_central_width, ref_src_lines=ref_src_lines, srcline_wlist=srcline_wlist,
|
| 449 |
+
start_from_top=start_from_top)
|
| 450 |
+
else:
|
| 451 |
+
lines, adjust_xy = layout_lines_alignside(blk, mask, words, centroid, wl_list, delimiter_len, line_height, spacing, delimiter, False, max_central_width,
|
| 452 |
+
ref_src_lines=ref_src_lines, srcline_wlist=srcline_wlist)
|
| 453 |
+
|
| 454 |
+
concated_text = []
|
| 455 |
+
pos_x_lst, pos_right_lst = [], []
|
| 456 |
+
for line in lines:
|
| 457 |
+
pos_x_lst.append(line.pos_x)
|
| 458 |
+
pos_right_lst.append(max(line.pos_x, 0) + line.length)
|
| 459 |
+
concated_text.append(line.text)
|
| 460 |
+
concated_text = '\n'.join(concated_text)
|
| 461 |
+
|
| 462 |
+
pos_x_lst = np.array(pos_x_lst)
|
| 463 |
+
pos_right_lst = np.array(pos_right_lst)
|
| 464 |
+
canvas_l, canvas_r = pos_x_lst.min(), pos_right_lst.max()
|
| 465 |
+
canvas_t, canvas_b = lines[0].pos_y, lines[-1].pos_y + line_height
|
| 466 |
+
|
| 467 |
+
canvas_h = int(canvas_b - canvas_t)
|
| 468 |
+
canvas_w = int(canvas_r - canvas_l)
|
| 469 |
+
|
| 470 |
+
if alignment == 1:
|
| 471 |
+
abs_x = int(round(center_x - canvas_w / 2))
|
| 472 |
+
abs_y = int(round(center_y - canvas_h / 2))
|
| 473 |
+
else:
|
| 474 |
+
abs_x = shifted_x
|
| 475 |
+
abs_y = shifted_y
|
| 476 |
+
|
| 477 |
+
return concated_text, [abs_x, abs_y, canvas_w, canvas_h], start_from_top, adjust_xy
|
utils/text_processing.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple
|
| 2 |
+
import json
|
| 3 |
+
import os.path as osp
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
HALF2FULL = {i: i + 0xFEE0 for i in range(0x21, 0x7F)}
|
| 7 |
+
HALF2FULL[0x20] = 0x3000
|
| 8 |
+
|
| 9 |
+
FULL2HALF = dict((i + 0xFEE0, i) for i in range(0x21, 0x7F))
|
| 10 |
+
FULL2HALF[0x3000] = 0x20
|
| 11 |
+
FULL2HALF[0x3002] = 0x2E
|
| 12 |
+
|
| 13 |
+
LANGSET_CJK = {'简体中文', '繁體中文', '日本語'}
|
| 14 |
+
LANGSET_CH = {'简体中文', '繁體中文'}
|
| 15 |
+
|
| 16 |
+
PUNSET_RIGHT_ENG = {'.', '?', '!', ':', ';', ')', '}', "\""}
|
| 17 |
+
PUNCTUATION_L = {'「', '『', '【', '《', '〈', '〔', '[', '{', '(', '(', '[', '{', '“', '‘'}
|
| 18 |
+
|
| 19 |
+
PKUSEG_PUNCSET = {' ', '.', ' '}
|
| 20 |
+
PKUSEGPATH = r'data/pkusegscores.json'
|
| 21 |
+
PKUSEGSCORES = None
|
| 22 |
+
CHSEG = None
|
| 23 |
+
|
| 24 |
+
def full_len(s: str):
|
| 25 |
+
"""
|
| 26 |
+
Convert all ASCII characters to their full-width counterpart.
|
| 27 |
+
https://stackoverflow.com/questions/2422177/python-how-can-i-replace-full-width-characters-with-half-width-characters
|
| 28 |
+
"""
|
| 29 |
+
return s.translate(HALF2FULL)
|
| 30 |
+
|
| 31 |
+
def half_len(s):
|
| 32 |
+
'''
|
| 33 |
+
Convert full-width characters to ASCII counterpart
|
| 34 |
+
'''
|
| 35 |
+
return s.translate(FULL2HALF)
|
| 36 |
+
|
| 37 |
+
def seg_to_chars(text: str) -> List[str]:
|
| 38 |
+
text = text.replace('\n', '')
|
| 39 |
+
return [c for c in text]
|
| 40 |
+
|
| 41 |
+
def seg_eng(text: str) -> List[str]:
|
| 42 |
+
text = text.replace(' ', ' ').replace(' .', '.').replace('\n', ' ')
|
| 43 |
+
processed_text = ''
|
| 44 |
+
|
| 45 |
+
# dumb way to insure spaces between words
|
| 46 |
+
text_len = len(text)
|
| 47 |
+
for ii, c in enumerate(text):
|
| 48 |
+
if c in PUNSET_RIGHT_ENG and ii < text_len - 1:
|
| 49 |
+
next_c = text[ii + 1]
|
| 50 |
+
if next_c.isalpha() or next_c.isnumeric():
|
| 51 |
+
processed_text += c + ' '
|
| 52 |
+
else:
|
| 53 |
+
processed_text += c
|
| 54 |
+
else:
|
| 55 |
+
processed_text += c
|
| 56 |
+
|
| 57 |
+
word_list = processed_text.split(' ')
|
| 58 |
+
word_num = len(word_list)
|
| 59 |
+
if word_num <= 1:
|
| 60 |
+
return word_list
|
| 61 |
+
|
| 62 |
+
words = []
|
| 63 |
+
skip_next = False
|
| 64 |
+
for ii, word in enumerate(word_list):
|
| 65 |
+
if skip_next:
|
| 66 |
+
skip_next = False
|
| 67 |
+
continue
|
| 68 |
+
if len(word) < 3:
|
| 69 |
+
append_left, append_right = False, False
|
| 70 |
+
len_word, len_next, len_prev = len(word), -1, -1
|
| 71 |
+
if ii < word_num - 1:
|
| 72 |
+
len_next = len(word_list[ii + 1])
|
| 73 |
+
if ii > 0:
|
| 74 |
+
len_prev = len(words[-1])
|
| 75 |
+
cond_next = (len_word == 2 and len_next <= 4) or len_word == 1
|
| 76 |
+
cond_prev = (len_word == 2 and len_prev <= 4) or len_word == 1
|
| 77 |
+
if len_next > 0 and len_prev > 0:
|
| 78 |
+
if len_next < len_prev:
|
| 79 |
+
append_right = cond_next
|
| 80 |
+
else:
|
| 81 |
+
append_left = cond_prev
|
| 82 |
+
elif len_next > 0:
|
| 83 |
+
append_right = cond_next
|
| 84 |
+
elif len_prev > 0:
|
| 85 |
+
append_left = cond_prev
|
| 86 |
+
|
| 87 |
+
if append_left:
|
| 88 |
+
words[-1] = words[-1] + ' ' + word
|
| 89 |
+
elif append_right:
|
| 90 |
+
words.append(word + ' ' + word_list[ii + 1])
|
| 91 |
+
skip_next = True
|
| 92 |
+
else:
|
| 93 |
+
words.append(word)
|
| 94 |
+
continue
|
| 95 |
+
words.append(word)
|
| 96 |
+
return words
|
| 97 |
+
|
| 98 |
+
def _seg_ch_pkg(text: str) -> List[str]:
|
| 99 |
+
|
| 100 |
+
if text == ' ':
|
| 101 |
+
return [' ']
|
| 102 |
+
elif text == '':
|
| 103 |
+
return []
|
| 104 |
+
|
| 105 |
+
segments = CHSEG.cut(text)
|
| 106 |
+
num_segments = len(segments)
|
| 107 |
+
if num_segments == 0:
|
| 108 |
+
return []
|
| 109 |
+
if num_segments == 1:
|
| 110 |
+
return [segments[0][0]]
|
| 111 |
+
|
| 112 |
+
words = []
|
| 113 |
+
tags = []
|
| 114 |
+
max_concat_len = 4
|
| 115 |
+
skip_next = False
|
| 116 |
+
try:
|
| 117 |
+
for ii, (word, tag) in enumerate(segments):
|
| 118 |
+
if skip_next:
|
| 119 |
+
skip_next = False
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
len_word, len_next, len_prev = len(word), -1, -1
|
| 123 |
+
next_valid, prev_valid = False, False
|
| 124 |
+
word_next, tag_next = '', ''
|
| 125 |
+
word_prev, tag_prev = '', ''
|
| 126 |
+
score_next, score_prev = 0, 0
|
| 127 |
+
if ii < num_segments - 1:
|
| 128 |
+
word_next, tag_next = segments[ii + 1]
|
| 129 |
+
len_next = len(word_next)
|
| 130 |
+
next_valid = True
|
| 131 |
+
if tag_next != 'w' and not word_next in PKUSEG_PUNCSET:
|
| 132 |
+
score_next = PKUSEGSCORES[tag][tag_next]
|
| 133 |
+
|
| 134 |
+
if ii > 0:
|
| 135 |
+
word_prev, tag_prev = words[-1], segments[ii - 1][1]
|
| 136 |
+
len_prev = len(word_prev)
|
| 137 |
+
prev_valid = True
|
| 138 |
+
if tag_prev != 'w' and not word_prev[-1] in PKUSEG_PUNCSET:
|
| 139 |
+
score_prev = PKUSEGSCORES[tag_prev][tag]
|
| 140 |
+
|
| 141 |
+
append_prev, append_next = False, False
|
| 142 |
+
|
| 143 |
+
if tag == 'w' or word in PKUSEG_PUNCSET: # puntuation
|
| 144 |
+
if word in PUNCTUATION_L:
|
| 145 |
+
append_next = next_valid
|
| 146 |
+
elif len_word <= 1:
|
| 147 |
+
append_prev = prev_valid
|
| 148 |
+
else:
|
| 149 |
+
next_valid = score_next > 0 and len_next < max_concat_len
|
| 150 |
+
prev_valid = score_prev > 0 and len_prev < max_concat_len
|
| 151 |
+
need_concat = len_word < max_concat_len
|
| 152 |
+
append_prev = score_prev == 1
|
| 153 |
+
append_next = score_next == 1
|
| 154 |
+
if score_prev != 1 and score_next != 1 and need_concat:
|
| 155 |
+
append_prev = prev_valid
|
| 156 |
+
append_next = next_valid
|
| 157 |
+
if append_next and append_prev:
|
| 158 |
+
if len_prev == len_next:
|
| 159 |
+
if score_prev >= score_next:
|
| 160 |
+
append_next = False
|
| 161 |
+
else:
|
| 162 |
+
append_prev = False
|
| 163 |
+
elif len_prev < len_next:
|
| 164 |
+
append_next = False
|
| 165 |
+
else:
|
| 166 |
+
append_prev = False
|
| 167 |
+
|
| 168 |
+
if append_next and append_prev:
|
| 169 |
+
words[-1] = word_prev + word + word_next
|
| 170 |
+
tags[-1] = tags[-1] + [tag, tag_next]
|
| 171 |
+
skip_next = True
|
| 172 |
+
elif append_prev:
|
| 173 |
+
words[-1] = words[-1] + word
|
| 174 |
+
tags[-1].append(tag)
|
| 175 |
+
elif append_next:
|
| 176 |
+
words.append(word + word_next)
|
| 177 |
+
tags.append([tag, tag_next])
|
| 178 |
+
skip_next = True
|
| 179 |
+
else:
|
| 180 |
+
words.append(word)
|
| 181 |
+
tags.append([tag])
|
| 182 |
+
except Exception as e:
|
| 183 |
+
print('exp at line: ', text)
|
| 184 |
+
raise e
|
| 185 |
+
return words
|
| 186 |
+
|
| 187 |
+
def seg_ch_pkg(text: str):
|
| 188 |
+
|
| 189 |
+
global CHSEG
|
| 190 |
+
if CHSEG is None:
|
| 191 |
+
try:
|
| 192 |
+
import pkuseg
|
| 193 |
+
except:
|
| 194 |
+
import spacy_pkuseg as pkuseg
|
| 195 |
+
CHSEG = pkuseg.pkuseg(postag=True)
|
| 196 |
+
|
| 197 |
+
# pkuseg won't work with half-width punctuations
|
| 198 |
+
fullen_text = full_len(text).replace(' ', ' ')
|
| 199 |
+
cvt_back = False
|
| 200 |
+
if fullen_text != text:
|
| 201 |
+
cvt_back = True
|
| 202 |
+
text = fullen_text
|
| 203 |
+
|
| 204 |
+
global PKUSEGSCORES
|
| 205 |
+
if PKUSEGSCORES is None:
|
| 206 |
+
with open(PKUSEGPATH, 'r', encoding='utf8') as f:
|
| 207 |
+
PKUSEGSCORES = json.loads(f.read())
|
| 208 |
+
|
| 209 |
+
text_list = text.replace('\n', '').replace(' ', ' ').split(' ')
|
| 210 |
+
result_list = []
|
| 211 |
+
for ii, text in enumerate(text_list):
|
| 212 |
+
words = None
|
| 213 |
+
if text:
|
| 214 |
+
words = _seg_ch_pkg(text)
|
| 215 |
+
if words is not None:
|
| 216 |
+
if ii > 0:
|
| 217 |
+
words[0] = ' ' + words[0]
|
| 218 |
+
result_list.extend(words)
|
| 219 |
+
|
| 220 |
+
if cvt_back:
|
| 221 |
+
# pkuseg w
|
| 222 |
+
result_list = [half_len(word) for word in result_list]
|
| 223 |
+
return result_list
|
| 224 |
+
|
| 225 |
+
def seg_text(text: str, lang: str) -> Tuple[List, str]:
|
| 226 |
+
delimiter = ''
|
| 227 |
+
if lang in LANGSET_CH:
|
| 228 |
+
words = seg_ch_pkg(text)
|
| 229 |
+
elif lang in LANGSET_CJK:
|
| 230 |
+
words = seg_to_chars(text)
|
| 231 |
+
else:
|
| 232 |
+
words = seg_eng(text)
|
| 233 |
+
delimiter = ' '
|
| 234 |
+
return words, delimiter
|
| 235 |
+
|
| 236 |
+
def is_cjk(lang: str) -> bool:
|
| 237 |
+
return lang in LANGSET_CJK
|
utils/textblock.py
ADDED
|
@@ -0,0 +1,908 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Tuple, Callable
|
| 2 |
+
import numpy as np
|
| 3 |
+
from shapely.geometry import Polygon
|
| 4 |
+
import math
|
| 5 |
+
import copy
|
| 6 |
+
import cv2
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
from .imgproc_utils import union_area, xywh2xyxypoly, rotate_polygons, color_difference
|
| 10 |
+
from .structures import Union, List, Dict, field, nested_dataclass
|
| 11 |
+
from .split_text_region import split_textblock as split_text_region
|
| 12 |
+
from .fontformat import FontFormat, LineSpacingType, TextAlignment, fix_fontweight_qt
|
| 13 |
+
from .textblock_mask import canny_flood
|
| 14 |
+
from .textlines_merge import sort_pnts, Quadrilateral, merge_bboxes_text_region
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
LANG_LIST = ['eng', 'ja', 'unknown']
|
| 18 |
+
LANGCLS2IDX = {'eng': 0, 'ja': 1, 'unknown': 2}
|
| 19 |
+
|
| 20 |
+
# https://ayaka.shn.hk/hanregex/
|
| 21 |
+
# https://medium.com/the-artificial-impostor/detecting-chinese-characters-in-unicode-strings-4ac839ba313a
|
| 22 |
+
CJKPATTERN = re.compile(r'[\uac00-\ud7a3\u3040-\u30ff\u4e00-\u9FFF]')
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@nested_dataclass
|
| 26 |
+
class TextBlock:
|
| 27 |
+
xyxy: List = field(default_factory = lambda: [0, 0, 0, 0])
|
| 28 |
+
lines: List = field(default_factory = lambda: [])
|
| 29 |
+
language: str = 'unknown'
|
| 30 |
+
# font_size: float = -1.
|
| 31 |
+
distance: np.ndarray = None
|
| 32 |
+
angle: int = 0
|
| 33 |
+
vec: List = None
|
| 34 |
+
norm: float = -1
|
| 35 |
+
merged: bool = False
|
| 36 |
+
text: List = field(default_factory = lambda : [])
|
| 37 |
+
translation: str = ""
|
| 38 |
+
rich_text: str = ""
|
| 39 |
+
_bounding_rect: List = None
|
| 40 |
+
src_is_vertical: bool = None
|
| 41 |
+
_detected_font_size: float = -1
|
| 42 |
+
det_model: str = None
|
| 43 |
+
|
| 44 |
+
region_mask: np.ndarray = None
|
| 45 |
+
region_inpaint_dict: Dict = None
|
| 46 |
+
|
| 47 |
+
fontformat: FontFormat = field(default_factory=lambda: FontFormat())
|
| 48 |
+
|
| 49 |
+
deprecated_attributes: dict = field(default_factory = lambda: dict())
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def vertical(self):
|
| 53 |
+
return self.fontformat.vertical
|
| 54 |
+
|
| 55 |
+
@vertical.setter
|
| 56 |
+
def vertical(self, value: bool):
|
| 57 |
+
self.fontformat.vertical = value
|
| 58 |
+
|
| 59 |
+
@property
|
| 60 |
+
def font_size(self):
|
| 61 |
+
return self.fontformat.font_size
|
| 62 |
+
|
| 63 |
+
@font_size.setter
|
| 64 |
+
def font_size(self, value: float):
|
| 65 |
+
self.fontformat.font_size = value
|
| 66 |
+
|
| 67 |
+
@property
|
| 68 |
+
def line_spacing(self):
|
| 69 |
+
return self.fontformat.line_spacing
|
| 70 |
+
|
| 71 |
+
@line_spacing.setter
|
| 72 |
+
def line_spacing(self, value: float):
|
| 73 |
+
self.fontformat.line_spacing = value
|
| 74 |
+
|
| 75 |
+
@property
|
| 76 |
+
def letter_spacing(self):
|
| 77 |
+
return self.fontformat.letter_spacing
|
| 78 |
+
|
| 79 |
+
@letter_spacing.setter
|
| 80 |
+
def letter_spacing(self, value: float):
|
| 81 |
+
self.fontformat.letter_spacing = value
|
| 82 |
+
|
| 83 |
+
@property
|
| 84 |
+
def font_family(self):
|
| 85 |
+
return self.fontformat.font_family
|
| 86 |
+
|
| 87 |
+
@font_family.setter
|
| 88 |
+
def font_family(self, value: str):
|
| 89 |
+
self.fontformat.font_family = value
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def font_weight(self):
|
| 93 |
+
return self.fontformat.font_weight
|
| 94 |
+
|
| 95 |
+
@font_weight.setter
|
| 96 |
+
def font_weight(self, value: int):
|
| 97 |
+
self.fontformat.font_weight = value
|
| 98 |
+
|
| 99 |
+
@property
|
| 100 |
+
def bold(self):
|
| 101 |
+
return self.fontformat.bold
|
| 102 |
+
|
| 103 |
+
@bold.setter
|
| 104 |
+
def bold(self, value: bool):
|
| 105 |
+
self.fontformat.bold = value
|
| 106 |
+
|
| 107 |
+
@property
|
| 108 |
+
def italic(self):
|
| 109 |
+
return self.fontformat.italic
|
| 110 |
+
|
| 111 |
+
@italic.setter
|
| 112 |
+
def italic(self, value: bool):
|
| 113 |
+
self.fontformat.italic = value
|
| 114 |
+
|
| 115 |
+
@property
|
| 116 |
+
def underline(self):
|
| 117 |
+
return self.fontformat.underline
|
| 118 |
+
|
| 119 |
+
@underline.setter
|
| 120 |
+
def underline(self, value: bool):
|
| 121 |
+
self.fontformat.underline = value
|
| 122 |
+
|
| 123 |
+
@property
|
| 124 |
+
def stroke_width(self):
|
| 125 |
+
return self.fontformat.stroke_width
|
| 126 |
+
|
| 127 |
+
@stroke_width.setter
|
| 128 |
+
def stroke_width(self, value: float):
|
| 129 |
+
self.fontformat.stroke_width = value
|
| 130 |
+
|
| 131 |
+
@property
|
| 132 |
+
def opacity(self):
|
| 133 |
+
return self.fontformat.opacity
|
| 134 |
+
|
| 135 |
+
@opacity.setter
|
| 136 |
+
def opacity(self, value: float):
|
| 137 |
+
self.fontformat.opacity = value
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def shadow_radius(self):
|
| 141 |
+
return self.fontformat.shadow_radius
|
| 142 |
+
|
| 143 |
+
@shadow_radius.setter
|
| 144 |
+
def shadow_radius(self, value: float):
|
| 145 |
+
self.fontformat.shadow_radius = value
|
| 146 |
+
|
| 147 |
+
@property
|
| 148 |
+
def shadow_strength(self):
|
| 149 |
+
return self.fontformat.shadow_strength
|
| 150 |
+
|
| 151 |
+
@shadow_strength.setter
|
| 152 |
+
def shadow_strength(self, value: float):
|
| 153 |
+
self.fontformat.shadow_strength = value
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def shadow_color(self):
|
| 157 |
+
return self.fontformat.shadow_color
|
| 158 |
+
|
| 159 |
+
@shadow_color.setter
|
| 160 |
+
def shadow_color(self, value: float):
|
| 161 |
+
self.fontformat.shadow_color = value
|
| 162 |
+
|
| 163 |
+
@property
|
| 164 |
+
def shadow_offset(self):
|
| 165 |
+
return self.fontformat.shadow_offset
|
| 166 |
+
|
| 167 |
+
@shadow_offset.setter
|
| 168 |
+
def shadow_offset(self, value: float):
|
| 169 |
+
self.fontformat.shadow_offset = value
|
| 170 |
+
|
| 171 |
+
@property
|
| 172 |
+
def fg_colors(self):
|
| 173 |
+
return self.fontformat.frgb
|
| 174 |
+
|
| 175 |
+
@fg_colors.setter
|
| 176 |
+
def fg_colors(self, value: Union[np.ndarray, List]):
|
| 177 |
+
self.fontformat.frgb = value
|
| 178 |
+
|
| 179 |
+
@property
|
| 180 |
+
def bg_colors(self):
|
| 181 |
+
return self.fontformat.srgb
|
| 182 |
+
|
| 183 |
+
@bg_colors.setter
|
| 184 |
+
def bg_colors(self, value: np.ndarray):
|
| 185 |
+
self.fontformat.srgb = value
|
| 186 |
+
|
| 187 |
+
@property
|
| 188 |
+
def alignment(self):
|
| 189 |
+
return self.fontformat.alignment
|
| 190 |
+
|
| 191 |
+
@alignment.setter
|
| 192 |
+
def alignment(self, value: int):
|
| 193 |
+
self.fontformat.alignment = value
|
| 194 |
+
|
| 195 |
+
def __post_init__(self):
|
| 196 |
+
if self.xyxy is not None:
|
| 197 |
+
self.xyxy = [int(num) for num in self.xyxy]
|
| 198 |
+
if self.distance is not None:
|
| 199 |
+
self.distance = np.array(self.distance, np.float32)
|
| 200 |
+
if self.vec is not None:
|
| 201 |
+
self.vec = np.array(self.vec, np.float32)
|
| 202 |
+
if self.src_is_vertical is None:
|
| 203 |
+
self.src_is_vertical = self.vertical
|
| 204 |
+
|
| 205 |
+
if self.rich_text:
|
| 206 |
+
self.rich_text = fix_fontweight_qt(self.rich_text)
|
| 207 |
+
|
| 208 |
+
da = self.deprecated_attributes
|
| 209 |
+
if len(da) > 0:
|
| 210 |
+
if 'accumulate_color' in da:
|
| 211 |
+
self.fg_colors = np.array([da['fg_r'], da['fg_g'], da['fg_b']], dtype=np.float32)
|
| 212 |
+
self.bg_colors = np.array([da['bg_r'], da['bg_g'], da['bg_b']], dtype=np.float32)
|
| 213 |
+
nlines = len(self)
|
| 214 |
+
if da['accumulate_color'] and len(self) > 0:
|
| 215 |
+
self.fg_colors /= nlines
|
| 216 |
+
self.bg_colors /= nlines
|
| 217 |
+
|
| 218 |
+
deprecated_blk_fmt_keys = {'vertical': None, 'line_spacing': None, 'letter_spacing': None, 'bold': None, 'underline': None, 'italic': None,
|
| 219 |
+
'opacity': None, 'shadow_radius': None, 'shadow_strength': None, 'shadow_color': None, 'shadow_offset': None,
|
| 220 |
+
'font_size': 'size', 'font_family': None, '_alignment': 'alignment', 'default_stroke_width': 'stroke_width', 'font_weight': None,
|
| 221 |
+
'fg_colors': 'frgb', 'bg_colors': 'srgb'
|
| 222 |
+
}
|
| 223 |
+
for src_k, v in da.items():
|
| 224 |
+
if src_k in deprecated_blk_fmt_keys:
|
| 225 |
+
if deprecated_blk_fmt_keys[src_k] is None:
|
| 226 |
+
tgt_k = src_k
|
| 227 |
+
else:
|
| 228 |
+
tgt_k = deprecated_blk_fmt_keys[src_k]
|
| 229 |
+
setattr(self.fontformat, tgt_k, v)
|
| 230 |
+
self.font_weight = fix_fontweight_qt(self.font_weight)
|
| 231 |
+
|
| 232 |
+
del self.deprecated_attributes
|
| 233 |
+
|
| 234 |
+
@property
|
| 235 |
+
def detected_font_size(self):
|
| 236 |
+
if self._detected_font_size > 0:
|
| 237 |
+
return self._detected_font_size
|
| 238 |
+
return self.font_size
|
| 239 |
+
|
| 240 |
+
def adjust_bbox(self, with_bbox=False, x_range=None, y_range=None):
|
| 241 |
+
lines = self.lines_array().astype(np.int32)
|
| 242 |
+
if with_bbox:
|
| 243 |
+
self.xyxy[0] = min(lines[..., 0].min(), self.xyxy[0])
|
| 244 |
+
self.xyxy[1] = min(lines[..., 1].min(), self.xyxy[1])
|
| 245 |
+
self.xyxy[2] = max(lines[..., 0].max(), self.xyxy[2])
|
| 246 |
+
self.xyxy[3] = max(lines[..., 1].max(), self.xyxy[3])
|
| 247 |
+
else:
|
| 248 |
+
self.xyxy[0] = lines[..., 0].min()
|
| 249 |
+
self.xyxy[1] = lines[..., 1].min()
|
| 250 |
+
self.xyxy[2] = lines[..., 0].max()
|
| 251 |
+
self.xyxy[3] = lines[..., 1].max()
|
| 252 |
+
|
| 253 |
+
if x_range is not None:
|
| 254 |
+
self.xyxy[0] = np.clip(self.xyxy[0], x_range[0], x_range[1])
|
| 255 |
+
self.xyxy[2] = np.clip(self.xyxy[2], x_range[0], x_range[1])
|
| 256 |
+
if y_range is not None:
|
| 257 |
+
self.xyxy[1] = np.clip(self.xyxy[1], y_range[0], y_range[1])
|
| 258 |
+
self.xyxy[3] = np.clip(self.xyxy[3], y_range[0], y_range[1])
|
| 259 |
+
|
| 260 |
+
def sort_lines(self):
|
| 261 |
+
if self.distance is not None:
|
| 262 |
+
idx = np.argsort(self.distance)
|
| 263 |
+
self.distance = self.distance[idx]
|
| 264 |
+
lines = np.array(self.lines, dtype=np.int32)
|
| 265 |
+
self.lines = lines[idx].tolist()
|
| 266 |
+
|
| 267 |
+
def lines_array(self, dtype=np.float64):
|
| 268 |
+
return np.array(self.lines, dtype=dtype)
|
| 269 |
+
|
| 270 |
+
def set_lines_by_xywh(self, xywh: np.ndarray, angle=0, x_range=None, y_range=None, adjust_bbox=False):
|
| 271 |
+
if isinstance(xywh, List):
|
| 272 |
+
xywh = np.array(xywh)
|
| 273 |
+
lines = xywh2xyxypoly(np.array([xywh]))
|
| 274 |
+
if angle != 0:
|
| 275 |
+
cx, cy = xywh[0], xywh[1]
|
| 276 |
+
cx += xywh[2] / 2.
|
| 277 |
+
cy += xywh[3] / 2.
|
| 278 |
+
lines = rotate_polygons([cx, cy], lines, angle)
|
| 279 |
+
|
| 280 |
+
lines = lines.reshape(-1, 4, 2)
|
| 281 |
+
if x_range is not None:
|
| 282 |
+
lines[..., 0] = np.clip(lines[..., 0], x_range[0], x_range[1])
|
| 283 |
+
if y_range is not None:
|
| 284 |
+
lines[..., 1] = np.clip(lines[..., 1], y_range[0], y_range[1])
|
| 285 |
+
self.lines = lines.tolist()
|
| 286 |
+
|
| 287 |
+
if adjust_bbox:
|
| 288 |
+
self.adjust_bbox()
|
| 289 |
+
|
| 290 |
+
def aspect_ratio(self) -> float:
|
| 291 |
+
min_rect = self.min_rect()
|
| 292 |
+
middle_pnts = (min_rect[:, [1, 2, 3, 0]] + min_rect) / 2
|
| 293 |
+
norm_v = np.linalg.norm(middle_pnts[:, 2] - middle_pnts[:, 0])
|
| 294 |
+
norm_h = np.linalg.norm(middle_pnts[:, 1] - middle_pnts[:, 3])
|
| 295 |
+
return norm_v / norm_h
|
| 296 |
+
|
| 297 |
+
def center(self) -> np.ndarray:
|
| 298 |
+
xyxy = np.array(self.xyxy)
|
| 299 |
+
return (xyxy[:2] + xyxy[2:]) / 2
|
| 300 |
+
|
| 301 |
+
def unrotated_polygons(self, ids=None) -> np.ndarray:
|
| 302 |
+
angled = self.angle != 0
|
| 303 |
+
center = self.center()
|
| 304 |
+
polygons = self.lines_array().reshape(-1, 8)
|
| 305 |
+
if ids is not None:
|
| 306 |
+
polygons = polygons[ids]
|
| 307 |
+
if angled:
|
| 308 |
+
polygons = rotate_polygons(center, polygons, self.angle)
|
| 309 |
+
return angled, center, polygons
|
| 310 |
+
|
| 311 |
+
def min_rect(self, rotate_back=True, ids=None) -> List[int]:
|
| 312 |
+
angled, center, polygons = self.unrotated_polygons(ids=ids)
|
| 313 |
+
min_x = polygons[:, ::2].min()
|
| 314 |
+
min_y = polygons[:, 1::2].min()
|
| 315 |
+
max_x = polygons[:, ::2].max()
|
| 316 |
+
max_y = polygons[:, 1::2].max()
|
| 317 |
+
min_bbox = np.array([[min_x, min_y, max_x, min_y, max_x, max_y, min_x, max_y]])
|
| 318 |
+
if angled and rotate_back:
|
| 319 |
+
min_bbox = rotate_polygons(center, min_bbox, -self.angle)
|
| 320 |
+
return min_bbox.reshape(-1, 4, 2).astype(np.int64)
|
| 321 |
+
|
| 322 |
+
def normalizd_width_list(self, normalize=True):
|
| 323 |
+
angled, center, polygons = self.unrotated_polygons()
|
| 324 |
+
width_list = []
|
| 325 |
+
for polygon in polygons:
|
| 326 |
+
width_list.append((polygon[[2, 4]] - polygon[[0, 6]]).mean())
|
| 327 |
+
sum_width = sum(width_list)
|
| 328 |
+
if normalize:
|
| 329 |
+
width_list = np.array(width_list)
|
| 330 |
+
width_list = width_list / sum_width
|
| 331 |
+
width_list = width_list.tolist()
|
| 332 |
+
return width_list, sum_width
|
| 333 |
+
|
| 334 |
+
# equivalent to qt's boundingRect, ignore angle
|
| 335 |
+
def bounding_rect(self) -> List[int]:
|
| 336 |
+
if self._bounding_rect is None:
|
| 337 |
+
# if True:
|
| 338 |
+
min_bbox = self.min_rect(rotate_back=False)[0]
|
| 339 |
+
x, y = min_bbox[0]
|
| 340 |
+
w, h = min_bbox[2] - min_bbox[0]
|
| 341 |
+
return [int(x), int(y), int(w), int(h)]
|
| 342 |
+
return self._bounding_rect
|
| 343 |
+
|
| 344 |
+
def __getattribute__(self, name: str):
|
| 345 |
+
if name == 'pts':
|
| 346 |
+
return self.lines_array()
|
| 347 |
+
# else:
|
| 348 |
+
return object.__getattribute__(self, name)
|
| 349 |
+
|
| 350 |
+
def __len__(self):
|
| 351 |
+
return len(self.lines)
|
| 352 |
+
|
| 353 |
+
def __getitem__(self, idx):
|
| 354 |
+
return self.lines[idx]
|
| 355 |
+
|
| 356 |
+
def to_dict(self, deep_copy=False):
|
| 357 |
+
blk_dict = vars(self)
|
| 358 |
+
if deep_copy:
|
| 359 |
+
blk_dict = copy.deepcopy(blk_dict)
|
| 360 |
+
return blk_dict
|
| 361 |
+
|
| 362 |
+
def get_transformed_region(self, img: np.ndarray, idx: int, textheight: int, maxwidth: int = None) -> np.ndarray :
|
| 363 |
+
im_h, im_w = img.shape[:2]
|
| 364 |
+
|
| 365 |
+
line = np.round(np.array(self.lines[idx])).astype(np.int64)
|
| 366 |
+
|
| 367 |
+
if not self.src_is_vertical and self.det_model == 'ctd':
|
| 368 |
+
# ctd detected horizontal bbox is smaller than GT
|
| 369 |
+
expand_size = max(int(self._detected_font_size * 0.1), 3)
|
| 370 |
+
rad = np.deg2rad(self.angle)
|
| 371 |
+
shifted_vec = np.array([[[-1, -1],[1, -1],[1, 1],[-1, 1]]])
|
| 372 |
+
shifted_vec = shifted_vec * np.array([[[np.sin(rad), np.cos(rad)]]]) * expand_size
|
| 373 |
+
line = line + shifted_vec
|
| 374 |
+
line[..., 0] = np.clip(line[..., 0], 0, im_w)
|
| 375 |
+
line[..., 1] = np.clip(line[..., 1], 0, im_h)
|
| 376 |
+
line = np.round(line[0]).astype(np.int64)
|
| 377 |
+
|
| 378 |
+
x1, y1, x2, y2 = line[:, 0].min(), line[:, 1].min(), line[:, 0].max(), line[:, 1].max()
|
| 379 |
+
|
| 380 |
+
x1 = np.clip(x1, 0, im_w)
|
| 381 |
+
y1 = np.clip(y1, 0, im_h)
|
| 382 |
+
x2 = np.clip(x2, 0, im_w)
|
| 383 |
+
y2 = np.clip(y2, 0, im_h)
|
| 384 |
+
img_croped = img[y1: y2, x1: x2]
|
| 385 |
+
|
| 386 |
+
direction = 'v' if self.src_is_vertical else 'h'
|
| 387 |
+
|
| 388 |
+
src_pts = line.copy()
|
| 389 |
+
src_pts[:, 0] -= x1
|
| 390 |
+
src_pts[:, 1] -= y1
|
| 391 |
+
middle_pnt = (src_pts[[1, 2, 3, 0]] + src_pts) / 2
|
| 392 |
+
vec_v = middle_pnt[2] - middle_pnt[0] # vertical vectors of textlines
|
| 393 |
+
vec_h = middle_pnt[1] - middle_pnt[3] # horizontal vectors of textlines
|
| 394 |
+
norm_v = np.linalg.norm(vec_v)
|
| 395 |
+
norm_h = np.linalg.norm(vec_h)
|
| 396 |
+
|
| 397 |
+
if textheight is None:
|
| 398 |
+
if direction == 'h' :
|
| 399 |
+
textheight = int(norm_v)
|
| 400 |
+
else:
|
| 401 |
+
textheight = int(norm_h)
|
| 402 |
+
|
| 403 |
+
if norm_v <= 0 or norm_h <= 0:
|
| 404 |
+
print('invalid textpolygon to target img')
|
| 405 |
+
return np.zeros((textheight, textheight, 3), dtype=np.uint8)
|
| 406 |
+
ratio = norm_v / norm_h
|
| 407 |
+
|
| 408 |
+
if direction == 'h' :
|
| 409 |
+
h = int(textheight)
|
| 410 |
+
w = int(round(textheight / ratio))
|
| 411 |
+
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
|
| 412 |
+
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
| 413 |
+
if M is None:
|
| 414 |
+
print('invalid textpolygon to target img')
|
| 415 |
+
return np.zeros((textheight, textheight, 3), dtype=np.uint8)
|
| 416 |
+
region = cv2.warpPerspective(img_croped, M, (w, h))
|
| 417 |
+
elif direction == 'v' :
|
| 418 |
+
w = int(textheight)
|
| 419 |
+
h = int(round(textheight * ratio))
|
| 420 |
+
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
|
| 421 |
+
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
| 422 |
+
if M is None:
|
| 423 |
+
print('invalid textpolygon to target img')
|
| 424 |
+
return np.zeros((textheight, textheight, 3), dtype=np.uint8)
|
| 425 |
+
region = cv2.warpPerspective(img_croped, M, (w, h))
|
| 426 |
+
region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
| 427 |
+
|
| 428 |
+
if maxwidth is not None:
|
| 429 |
+
h, w = region.shape[: 2]
|
| 430 |
+
if w > maxwidth:
|
| 431 |
+
region = cv2.resize(region, (maxwidth, h))
|
| 432 |
+
|
| 433 |
+
return region
|
| 434 |
+
|
| 435 |
+
def get_text(self) -> str:
|
| 436 |
+
if isinstance(self.text, str):
|
| 437 |
+
return self.text
|
| 438 |
+
text = ''
|
| 439 |
+
for t in self.text:
|
| 440 |
+
if text and t:
|
| 441 |
+
if text[-1].isalpha() and t[0].isalpha() \
|
| 442 |
+
and CJKPATTERN.search(text[-1]) is None \
|
| 443 |
+
and CJKPATTERN.search(t[0]) is None:
|
| 444 |
+
text += ' '
|
| 445 |
+
text += t
|
| 446 |
+
|
| 447 |
+
return text.strip()
|
| 448 |
+
|
| 449 |
+
def set_font_colors(self, fg_colors = None, bg_colors = None):
|
| 450 |
+
if fg_colors is not None:
|
| 451 |
+
self.fg_colors = fg_colors
|
| 452 |
+
if bg_colors is not None:
|
| 453 |
+
self.bg_colors = bg_colors
|
| 454 |
+
|
| 455 |
+
def update_font_colors(self, fg_colors: np.ndarray, bg_colors: np.ndarray):
|
| 456 |
+
nlines = len(self)
|
| 457 |
+
if nlines > 0:
|
| 458 |
+
if not isinstance(fg_colors, np.ndarray):
|
| 459 |
+
fg_colors = np.array(fg_colors, dtype=np.float32)
|
| 460 |
+
if not isinstance(bg_colors, np.ndarray):
|
| 461 |
+
bg_colors = np.array(bg_colors, dtype=np.float32)
|
| 462 |
+
if not isinstance(self.fg_colors, np.ndarray):
|
| 463 |
+
self.fg_colors = np.array(self.fg_colors, dtype=np.float32)
|
| 464 |
+
if not isinstance(self.bg_colors, np.ndarray):
|
| 465 |
+
self.bg_colors = np.array(self.bg_colors, dtype=np.float32)
|
| 466 |
+
self.fg_colors += fg_colors / nlines
|
| 467 |
+
self.bg_colors += bg_colors / nlines
|
| 468 |
+
|
| 469 |
+
def get_font_colors(self, bgr=False):
|
| 470 |
+
|
| 471 |
+
frgb = np.array(self.fg_colors).astype(np.int32)
|
| 472 |
+
brgb = np.array(self.bg_colors).astype(np.int32)
|
| 473 |
+
|
| 474 |
+
if bgr:
|
| 475 |
+
frgb = frgb[::-1]
|
| 476 |
+
brgb = brgb[::-1]
|
| 477 |
+
|
| 478 |
+
return frgb, brgb
|
| 479 |
+
|
| 480 |
+
def xywh(self):
|
| 481 |
+
x, y, w, h = self.xyxy
|
| 482 |
+
return [x, y, w-x, h-y]
|
| 483 |
+
|
| 484 |
+
def recalulate_alignment(self):
|
| 485 |
+
angled, center, polygons = self.unrotated_polygons()
|
| 486 |
+
polygons = polygons.reshape(-1, 4, 2)
|
| 487 |
+
|
| 488 |
+
left_std = np.std(polygons[:, 0, 0])
|
| 489 |
+
right_std = np.std(polygons[:, 1, 0])
|
| 490 |
+
center_std = np.std((polygons[:, 0, 0] + polygons[:, 1, 0]) / 2) * 0.7
|
| 491 |
+
|
| 492 |
+
if left_std < right_std and left_std < center_std:
|
| 493 |
+
self.alignment = TextAlignment.Left
|
| 494 |
+
elif right_std < left_std and right_std < center_std:
|
| 495 |
+
self.alignment = TextAlignment.Right
|
| 496 |
+
else:
|
| 497 |
+
self.alignment = TextAlignment.Center
|
| 498 |
+
|
| 499 |
+
def recalulate_stroke_width(self, color_diff_tol = 15, stroke_width: float = 0.2):
|
| 500 |
+
if color_difference(*self.get_font_colors()) < color_diff_tol:
|
| 501 |
+
self.stroke_width = 0.
|
| 502 |
+
else:
|
| 503 |
+
self.stroke_width = stroke_width
|
| 504 |
+
|
| 505 |
+
def adjust_pos(self, dx: int, dy: int):
|
| 506 |
+
self.xyxy[0] += dx
|
| 507 |
+
self.xyxy[1] += dy
|
| 508 |
+
self.xyxy[2] += dx
|
| 509 |
+
self.xyxy[3] += dy
|
| 510 |
+
if self._bounding_rect is not None:
|
| 511 |
+
self._bounding_rect[0] += dx
|
| 512 |
+
self._bounding_rect[1] += dy
|
| 513 |
+
|
| 514 |
+
def line_coord_valid(self, rect):
|
| 515 |
+
if self.det_model is None:
|
| 516 |
+
return False
|
| 517 |
+
if rect is None:
|
| 518 |
+
rect = self.bounding_rect()
|
| 519 |
+
|
| 520 |
+
min_bbox = self.min_rect(rotate_back=True)[0]
|
| 521 |
+
x1, y1 = min_bbox[0]
|
| 522 |
+
x2, y2 = min_bbox[2]
|
| 523 |
+
w = x2 - x1
|
| 524 |
+
h = y2 - y1
|
| 525 |
+
if w < 1 or h < 1:
|
| 526 |
+
return False
|
| 527 |
+
rx1, ry1, rx2, ry2 = rect
|
| 528 |
+
rx2 += rx1
|
| 529 |
+
ry2 += ry1
|
| 530 |
+
intersect = max(min(x2, rx2) - max(x1, rx1), 0) * max(min(y2, ry2) - max(y1, ry1), 0)
|
| 531 |
+
if intersect == 0:
|
| 532 |
+
return False
|
| 533 |
+
if intersect / (w * h) < 0.6:
|
| 534 |
+
return False
|
| 535 |
+
return True
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def sort_regions(regions: List[TextBlock], right_to_left=None) -> List[TextBlock]:
|
| 539 |
+
# from manga image translator
|
| 540 |
+
# Sort regions from right to left, top to bottom
|
| 541 |
+
|
| 542 |
+
nr = len(regions)
|
| 543 |
+
if right_to_left is None and nr > 0:
|
| 544 |
+
nv = 0
|
| 545 |
+
for r in regions:
|
| 546 |
+
if r.vertical:
|
| 547 |
+
nv += 1
|
| 548 |
+
right_to_left = nv / nr > 0
|
| 549 |
+
|
| 550 |
+
sorted_regions = []
|
| 551 |
+
for region in sorted(regions, key=lambda region: region.center()[1]):
|
| 552 |
+
for i, sorted_region in enumerate(sorted_regions):
|
| 553 |
+
if region.center()[1] > sorted_region.xyxy[3]:
|
| 554 |
+
continue
|
| 555 |
+
if region.center()[1] < sorted_region.xyxy[1]:
|
| 556 |
+
sorted_regions.insert(i + 1, region)
|
| 557 |
+
break
|
| 558 |
+
|
| 559 |
+
# y center of region inside sorted_region so sort by x instead
|
| 560 |
+
if right_to_left and region.center()[0] > sorted_region.center()[0]:
|
| 561 |
+
sorted_regions.insert(i, region)
|
| 562 |
+
break
|
| 563 |
+
if not right_to_left and region.center()[0] < sorted_region.center()[0]:
|
| 564 |
+
sorted_regions.insert(i, region)
|
| 565 |
+
break
|
| 566 |
+
else:
|
| 567 |
+
sorted_regions.append(region)
|
| 568 |
+
return sorted_regions
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
def examine_textblk(blk: TextBlock, im_w: int, im_h: int, sort: bool = False) -> None:
|
| 572 |
+
lines = blk.lines_array()
|
| 573 |
+
middle_pnts = (lines[:, [1, 2, 3, 0]] + lines) / 2
|
| 574 |
+
vec_v = middle_pnts[:, 2] - middle_pnts[:, 0] # vertical vectors of textlines
|
| 575 |
+
vec_h = middle_pnts[:, 1] - middle_pnts[:, 3] # horizontal vectors of textlines
|
| 576 |
+
# if sum of vertical vectors is longer, then text orientation is vertical, and vice versa.
|
| 577 |
+
center_pnts = (lines[:, 0] + lines[:, 2]) / 2
|
| 578 |
+
v = np.sum(vec_v, axis=0)
|
| 579 |
+
h = np.sum(vec_h, axis=0)
|
| 580 |
+
norm_v, norm_h = np.linalg.norm(v), np.linalg.norm(h)
|
| 581 |
+
vertical = blk.src_is_vertical
|
| 582 |
+
# calcuate distance between textlines and origin
|
| 583 |
+
if vertical:
|
| 584 |
+
primary_vec, primary_norm = v, norm_v
|
| 585 |
+
distance_vectors = center_pnts - np.array([[im_w, 0]], dtype=np.float64) # vertical manga text is read from right to left, so origin is (imw, 0)
|
| 586 |
+
font_size = int(round(norm_h / len(lines)))
|
| 587 |
+
else:
|
| 588 |
+
primary_vec, primary_norm = h, norm_h
|
| 589 |
+
distance_vectors = center_pnts - np.array([[0, 0]], dtype=np.float64)
|
| 590 |
+
font_size = int(round(norm_v / len(lines)))
|
| 591 |
+
|
| 592 |
+
rotation_angle = int(math.atan2(primary_vec[1], primary_vec[0]) / math.pi * 180) # rotation angle of textlines
|
| 593 |
+
distance = np.linalg.norm(distance_vectors, axis=1) # distance between textlinecenters and origin
|
| 594 |
+
rad_matrix = np.arccos(np.einsum('ij, j->i', distance_vectors, primary_vec) / (distance * primary_norm))
|
| 595 |
+
distance = np.abs(np.sin(rad_matrix) * distance)
|
| 596 |
+
blk.lines = lines.astype(np.int32).tolist()
|
| 597 |
+
blk.distance = distance
|
| 598 |
+
blk.angle = rotation_angle
|
| 599 |
+
if vertical:
|
| 600 |
+
blk.angle -= 90
|
| 601 |
+
if abs(blk.angle) < 3:
|
| 602 |
+
blk.angle = 0
|
| 603 |
+
blk.font_size = font_size
|
| 604 |
+
blk.vec = primary_vec
|
| 605 |
+
blk.norm = primary_norm
|
| 606 |
+
if sort:
|
| 607 |
+
blk.sort_lines()
|
| 608 |
+
|
| 609 |
+
def try_merge_textline(blk: TextBlock, blk2: TextBlock, fntsize_tol=1.7, distance_tol=2) -> bool:
|
| 610 |
+
if blk2.merged:
|
| 611 |
+
return False
|
| 612 |
+
fntsize_div = blk.font_size / blk2.font_size
|
| 613 |
+
num_l1, num_l2 = len(blk), len(blk2)
|
| 614 |
+
fntsz_avg = (blk.font_size * num_l1 + blk2.font_size * num_l2) / (num_l1 + num_l2)
|
| 615 |
+
vec_prod = blk.vec @ blk2.vec
|
| 616 |
+
vec_sum = blk.vec + blk2.vec
|
| 617 |
+
cos_vec = vec_prod / blk.norm / blk2.norm
|
| 618 |
+
# distance = blk2.distance[-1] - blk.distance[-1]
|
| 619 |
+
# distance_p1 = np.linalg.norm(np.array(blk2.lines[-1][0]) - np.array(blk.lines[-1][0]))
|
| 620 |
+
minrect1 = blk.min_rect(ids=[-1])[0]
|
| 621 |
+
xyxy1 = [*minrect1[0], *minrect1[2]]
|
| 622 |
+
minrect2 = blk2.min_rect(ids=[-1])[0]
|
| 623 |
+
xyxy2 = [*minrect2[0], *minrect2[2]]
|
| 624 |
+
distance_x = max(xyxy1[0], xyxy2[0]) - min(xyxy1[2], xyxy2[2])
|
| 625 |
+
distance_y = max(xyxy1[1], xyxy2[1]) - min(xyxy1[3], xyxy2[3])
|
| 626 |
+
|
| 627 |
+
l1, l2 = Polygon(blk.lines[-1]), Polygon(blk2.lines[-1])
|
| 628 |
+
if not l1.intersects(l2):
|
| 629 |
+
if blk.vertical:
|
| 630 |
+
if distance_y > 0:
|
| 631 |
+
return False
|
| 632 |
+
else:
|
| 633 |
+
if distance_x > 0:
|
| 634 |
+
return False
|
| 635 |
+
if fntsize_div > fntsize_tol or 1 / fntsize_div > fntsize_tol:
|
| 636 |
+
return False
|
| 637 |
+
if abs(cos_vec) < 0.866: # cos30
|
| 638 |
+
return False
|
| 639 |
+
# if distance > distance_tol * fntsz_avg:
|
| 640 |
+
# return False
|
| 641 |
+
if blk.vertical and blk2.vertical and distance_x > fntsz_avg * 0.8:
|
| 642 |
+
return False
|
| 643 |
+
if not blk.vertical and distance_y > fntsz_avg * 0.5:
|
| 644 |
+
return False
|
| 645 |
+
# merge
|
| 646 |
+
for line in blk2.lines:
|
| 647 |
+
blk.lines.append(line)
|
| 648 |
+
blk.vec = vec_sum
|
| 649 |
+
blk.angle = int(round(np.rad2deg(math.atan2(vec_sum[1], vec_sum[0]))))
|
| 650 |
+
if blk.vertical:
|
| 651 |
+
blk.angle -= 90
|
| 652 |
+
blk.norm = np.linalg.norm(vec_sum)
|
| 653 |
+
blk.distance = np.append(blk.distance, blk2.distance[-1])
|
| 654 |
+
blk.font_size = fntsz_avg
|
| 655 |
+
blk2.merged = True
|
| 656 |
+
return True
|
| 657 |
+
|
| 658 |
+
def merge_textlines(blk_list: List[TextBlock]) -> List[TextBlock]:
|
| 659 |
+
if len(blk_list) < 2:
|
| 660 |
+
return blk_list
|
| 661 |
+
blk_list.sort(key=lambda blk: blk.distance[0])
|
| 662 |
+
merged_list = []
|
| 663 |
+
for ii, current_blk in enumerate(blk_list):
|
| 664 |
+
if current_blk.merged:
|
| 665 |
+
continue
|
| 666 |
+
for jj, blk in enumerate(blk_list[ii+1:]):
|
| 667 |
+
try_merge_textline(current_blk, blk)
|
| 668 |
+
merged_list.append(current_blk)
|
| 669 |
+
for blk in merged_list:
|
| 670 |
+
blk.adjust_bbox(with_bbox=False)
|
| 671 |
+
return merged_list
|
| 672 |
+
|
| 673 |
+
def split_textblk(blk: TextBlock):
|
| 674 |
+
font_size, distance, lines = blk.font_size, blk.distance, blk.lines
|
| 675 |
+
l0 = np.array(blk.lines[0])
|
| 676 |
+
lines.sort(key=lambda line: np.linalg.norm(np.array(line[0]) - l0[0]))
|
| 677 |
+
distance_tol = font_size * 2
|
| 678 |
+
current_blk = copy.deepcopy(blk)
|
| 679 |
+
current_blk.lines = [l0]
|
| 680 |
+
sub_blk_list = [current_blk]
|
| 681 |
+
textblock_splitted = False
|
| 682 |
+
for jj, line in enumerate(lines[1:]):
|
| 683 |
+
l1, l2 = Polygon(lines[jj]), Polygon(line)
|
| 684 |
+
split = False
|
| 685 |
+
if not l1.intersects(l2):
|
| 686 |
+
line_disance = abs(distance[jj+1] - distance[jj])
|
| 687 |
+
if line_disance > distance_tol:
|
| 688 |
+
split = True
|
| 689 |
+
elif blk.vertical and abs(blk.angle) < 15:
|
| 690 |
+
if len(current_blk.lines) > 1 or line_disance > font_size:
|
| 691 |
+
split = abs(lines[jj][0][1] - line[0][1]) > font_size
|
| 692 |
+
if split:
|
| 693 |
+
current_blk = copy.deepcopy(current_blk)
|
| 694 |
+
current_blk.lines = [line]
|
| 695 |
+
sub_blk_list.append(current_blk)
|
| 696 |
+
else:
|
| 697 |
+
current_blk.lines.append(line)
|
| 698 |
+
if len(sub_blk_list) > 1:
|
| 699 |
+
textblock_splitted = True
|
| 700 |
+
for current_blk in sub_blk_list:
|
| 701 |
+
current_blk.adjust_bbox(with_bbox=False)
|
| 702 |
+
return textblock_splitted, sub_blk_list
|
| 703 |
+
|
| 704 |
+
def group_output(blks, lines, im_w, im_h, mask=None, sort_blklist=True, canvas=None) -> List[TextBlock]:
|
| 705 |
+
blk_list: List[TextBlock] = []
|
| 706 |
+
scattered_lines = {'ver': [], 'hor': []}
|
| 707 |
+
for bbox, cls, conf in zip(*blks):
|
| 708 |
+
# cls could give wrong result
|
| 709 |
+
blk_list.append(TextBlock(bbox, language=LANG_LIST[cls]))
|
| 710 |
+
|
| 711 |
+
# step1: filter & assign lines to textblocks
|
| 712 |
+
bbox_score_thresh = 0.4
|
| 713 |
+
mask_score_thresh = 0.1
|
| 714 |
+
for ii, line in enumerate(lines):
|
| 715 |
+
line, is_vertical = sort_pnts(line)
|
| 716 |
+
bx1, bx2 = line[:, 0].min(), line[:, 0].max()
|
| 717 |
+
by1, by2 = line[:, 1].min(), line[:, 1].max()
|
| 718 |
+
bbox_score, bbox_idx = -1, -1
|
| 719 |
+
line_area = (by2-by1) * (bx2-bx1)
|
| 720 |
+
for jj, blk in enumerate(blk_list):
|
| 721 |
+
score = union_area(blk.xyxy, [bx1, by1, bx2, by2]) / line_area
|
| 722 |
+
if bbox_score < score:
|
| 723 |
+
bbox_score = score
|
| 724 |
+
bbox_idx = jj
|
| 725 |
+
if bbox_score > bbox_score_thresh:
|
| 726 |
+
blk_list[bbox_idx].lines.append(line)
|
| 727 |
+
blk_list[bbox_idx].adjust_bbox(with_bbox=True)
|
| 728 |
+
else: # if no textblock was assigned, check whether there is "enough" textmask
|
| 729 |
+
if mask is not None:
|
| 730 |
+
mask_score = mask[by1: by2, bx1: bx2].mean() / 255
|
| 731 |
+
if mask_score < mask_score_thresh:
|
| 732 |
+
continue
|
| 733 |
+
blk = TextBlock([bx1, by1, bx2, by2], [line])
|
| 734 |
+
blk.vertical = blk.src_is_vertical = is_vertical
|
| 735 |
+
examine_textblk(blk, im_w, im_h, sort=False)
|
| 736 |
+
if blk.vertical:
|
| 737 |
+
scattered_lines['ver'].append(blk)
|
| 738 |
+
else:
|
| 739 |
+
scattered_lines['hor'].append(blk)
|
| 740 |
+
|
| 741 |
+
# step2: filter textblocks, sort & split textlines
|
| 742 |
+
final_blk_list = []
|
| 743 |
+
for blk in blk_list:
|
| 744 |
+
# filter textblocks
|
| 745 |
+
if len(blk.lines) == 0:
|
| 746 |
+
bx1, by1, bx2, by2 = blk.xyxy
|
| 747 |
+
if mask is not None:
|
| 748 |
+
mask_score = mask[by1: by2, bx1: bx2].mean() / 255
|
| 749 |
+
if mask_score < mask_score_thresh:
|
| 750 |
+
continue
|
| 751 |
+
xywh = np.array([[bx1, by1, bx2-bx1, by2-by1]])
|
| 752 |
+
blk.lines = xywh2xyxypoly(xywh).reshape(-1, 4, 2).tolist()
|
| 753 |
+
else:
|
| 754 |
+
blk.adjust_bbox(with_bbox=False)
|
| 755 |
+
examine_textblk(blk, im_w, im_h, sort=True)
|
| 756 |
+
|
| 757 |
+
# split manga text if there is a distance gap
|
| 758 |
+
textblock_splitted = False
|
| 759 |
+
if len(blk.lines) > 1:
|
| 760 |
+
if blk.language == 'ja':
|
| 761 |
+
textblock_splitted = True
|
| 762 |
+
elif blk.vertical:
|
| 763 |
+
textblock_splitted = True
|
| 764 |
+
# if textblock_splitted:
|
| 765 |
+
# textblock_splitted, sub_blk_list = split_textblk(blk)
|
| 766 |
+
# else:
|
| 767 |
+
sub_blk_list = [blk]
|
| 768 |
+
# modify textblock to fit its textlines
|
| 769 |
+
if not textblock_splitted:
|
| 770 |
+
for blk in sub_blk_list:
|
| 771 |
+
blk.adjust_bbox(with_bbox=True)
|
| 772 |
+
final_blk_list += sub_blk_list
|
| 773 |
+
|
| 774 |
+
_final_blk_list = []
|
| 775 |
+
for blk in final_blk_list:
|
| 776 |
+
if blk.vertical:
|
| 777 |
+
scattered_lines['ver'].append(blk)
|
| 778 |
+
else:
|
| 779 |
+
_final_blk_list.append(blk)
|
| 780 |
+
final_blk_list = _final_blk_list
|
| 781 |
+
# step3: merge scattered lines, sort textblocks by "grid"
|
| 782 |
+
final_blk_list += merge_textlines(scattered_lines['hor'])
|
| 783 |
+
final_blk_list += merge_textlines(scattered_lines['ver'])
|
| 784 |
+
if sort_blklist:
|
| 785 |
+
final_blk_list = sort_regions(final_blk_list, )
|
| 786 |
+
for blk in final_blk_list:
|
| 787 |
+
blk.distance = None
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
if len(final_blk_list) > 1:
|
| 791 |
+
_final_blks = [final_blk_list[0]]
|
| 792 |
+
for blk in final_blk_list[1:]:
|
| 793 |
+
ax1, ay1, ax2, ay2 = blk.xyxy
|
| 794 |
+
keep_blk = True
|
| 795 |
+
aarea = (ax2 - ax1) * (ay2 - ay1) + 1e-6
|
| 796 |
+
for eb in _final_blks:
|
| 797 |
+
bx1, by1, bx2, by2 = eb.xyxy
|
| 798 |
+
x1 = max(ax1, bx1)
|
| 799 |
+
y1 = max(ay1, by1)
|
| 800 |
+
x2 = min(ax2, bx2)
|
| 801 |
+
y2 = min(ay2, by2)
|
| 802 |
+
if y2 < y1 or x2 < x1:
|
| 803 |
+
continue
|
| 804 |
+
inter_area = (y2 - y1) * (x2 - x1)
|
| 805 |
+
if inter_area / aarea > 0.9:
|
| 806 |
+
keep_blk = False
|
| 807 |
+
break
|
| 808 |
+
if keep_blk:
|
| 809 |
+
_final_blks.append(blk)
|
| 810 |
+
final_blk_list = _final_blks
|
| 811 |
+
|
| 812 |
+
for blk in final_blk_list:
|
| 813 |
+
if blk.language != 'ja' and not blk.vertical:
|
| 814 |
+
num_lines = len(blk.lines)
|
| 815 |
+
if num_lines == 0:
|
| 816 |
+
continue
|
| 817 |
+
blk._detected_font_size = blk.font_size
|
| 818 |
+
|
| 819 |
+
return final_blk_list
|
| 820 |
+
|
| 821 |
+
def visualize_textblocks(canvas, blk_list: List[TextBlock]):
|
| 822 |
+
lw = max(round(sum(canvas.shape) / 2 * 0.003), 2) # line width
|
| 823 |
+
for ii, blk in enumerate(blk_list):
|
| 824 |
+
bx1, by1, bx2, by2 = blk.xyxy
|
| 825 |
+
cv2.rectangle(canvas, (bx1, by1), (bx2, by2), (127, 255, 127), lw)
|
| 826 |
+
lines = blk.lines_array(dtype=np.int32)
|
| 827 |
+
for jj, line in enumerate(lines):
|
| 828 |
+
cv2.putText(canvas, str(jj), line[0], cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255,127,0), 1)
|
| 829 |
+
cv2.polylines(canvas, [line], True, (0,127,255), 2)
|
| 830 |
+
cv2.polylines(canvas, [blk.min_rect()], True, (127,127,0), 2)
|
| 831 |
+
center = [int((bx1 + bx2)/2), int((by1 + by2)/2)]
|
| 832 |
+
cv2.putText(canvas, str(blk.angle), center, cv2.FONT_HERSHEY_SIMPLEX, 1, (127,127,255), 2)
|
| 833 |
+
cv2.putText(canvas, str(ii), (bx1, by1 + lw + 2), 0, lw / 3, (255,127,127), max(lw-1, 1), cv2.LINE_AA)
|
| 834 |
+
return canvas
|
| 835 |
+
|
| 836 |
+
def collect_textblock_regions(img: np.ndarray, textblk_lst: List[TextBlock], text_height=48, maxwidth=8100, split_textblk = False, seg_func: Callable = None):
|
| 837 |
+
regions = []
|
| 838 |
+
textblk_lst_indices = []
|
| 839 |
+
for blk_idx, textblk in enumerate(textblk_lst):
|
| 840 |
+
for ii in range(len(textblk)):
|
| 841 |
+
if split_textblk and len(textblk) == 1:
|
| 842 |
+
seg_func = canny_flood
|
| 843 |
+
region = textblk.get_transformed_region(img, ii, None, maxwidth=None)
|
| 844 |
+
mask = seg_func(region)[0]
|
| 845 |
+
split_lines = split_text_region(mask)[0]
|
| 846 |
+
for jj, line in enumerate(split_lines):
|
| 847 |
+
bottom = line[3]
|
| 848 |
+
if len(split_lines) == 1:
|
| 849 |
+
bottom = region.shape[0]
|
| 850 |
+
r = region[line[1]: bottom]
|
| 851 |
+
h, w = r.shape[:2]
|
| 852 |
+
tgt_h, tgt_w = text_height, min(maxwidth, int(text_height / h * w))
|
| 853 |
+
if tgt_h != h or tgt_w != w:
|
| 854 |
+
r = cv2.resize(r, (tgt_w, tgt_h), interpolation=cv2.INTER_LINEAR)
|
| 855 |
+
regions.append(r)
|
| 856 |
+
textblk_lst_indices.append(blk_idx)
|
| 857 |
+
# cv2.imwrite(f'local_region{jj}.jpg', r)
|
| 858 |
+
# cv2.imwrite('local_mask.jpg', mask)
|
| 859 |
+
# cv2.imwrite('local_region.jpg',region)
|
| 860 |
+
else:
|
| 861 |
+
textblk_lst_indices.append(blk_idx)
|
| 862 |
+
region = textblk.get_transformed_region(img, ii, text_height, maxwidth=maxwidth)
|
| 863 |
+
regions.append(region)
|
| 864 |
+
|
| 865 |
+
return regions, textblk_lst_indices
|
| 866 |
+
|
| 867 |
+
|
| 868 |
+
def mit_merge_textlines(textlines: List[Quadrilateral], width: int, height: int, verbose: bool = False) -> List[TextBlock]:
|
| 869 |
+
# from https://github.com/zyddnys/manga-image-translator
|
| 870 |
+
quadrilateral_lst = []
|
| 871 |
+
for line in textlines:
|
| 872 |
+
if not isinstance(line, Quadrilateral):
|
| 873 |
+
line = Quadrilateral(np.array(line), '', 1.)
|
| 874 |
+
quadrilateral_lst.append(line)
|
| 875 |
+
textlines = quadrilateral_lst
|
| 876 |
+
|
| 877 |
+
text_regions: List[TextBlock] = []
|
| 878 |
+
textlines_total_area = sum([txtln.area for txtln in textlines])
|
| 879 |
+
for (txtlns, fg_color, bg_color) in merge_bboxes_text_region(textlines, width, height):
|
| 880 |
+
total_logprobs = 0
|
| 881 |
+
for txtln in txtlns:
|
| 882 |
+
total_logprobs += np.log(txtln.prob) * txtln.area
|
| 883 |
+
|
| 884 |
+
total_logprobs /= textlines_total_area
|
| 885 |
+
font_size = int(min([txtln.font_size for txtln in txtlns]))
|
| 886 |
+
angle = np.rad2deg(np.mean([txtln.angle for txtln in txtlns])) - 90
|
| 887 |
+
if abs(angle) < 3:
|
| 888 |
+
angle = 0
|
| 889 |
+
lines = [txtln.pts for txtln in txtlns]
|
| 890 |
+
texts = [txtln.text for txtln in txtlns]
|
| 891 |
+
ffmt = FontFormat(font_size=font_size, frgb=fg_color, srgb=bg_color)
|
| 892 |
+
|
| 893 |
+
nv = 0
|
| 894 |
+
for txtln in txtlns:
|
| 895 |
+
if txtln.direction == 'v':
|
| 896 |
+
nv += 1
|
| 897 |
+
is_vertical = nv >= len(txtlns) // 2
|
| 898 |
+
region = TextBlock(
|
| 899 |
+
lines=lines, text=texts, angle=angle, fontformat=ffmt,
|
| 900 |
+
_detected_font_size=font_size, src_is_vertical=is_vertical, vertical=is_vertical)
|
| 901 |
+
region.adjust_bbox()
|
| 902 |
+
if region.src_is_vertical:
|
| 903 |
+
region.alignment = 1
|
| 904 |
+
else:
|
| 905 |
+
region.recalulate_alignment()
|
| 906 |
+
text_regions.append(region)
|
| 907 |
+
|
| 908 |
+
return text_regions
|
utils/textblock_mask.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
from .imgproc_utils import draw_connected_labels
|
| 5 |
+
from .stroke_width_calculator import strokewidth_check
|
| 6 |
+
|
| 7 |
+
opencv_inpaint = lambda img, mask: cv2.inpaint(img, mask, 3, cv2.INPAINT_NS)
|
| 8 |
+
|
| 9 |
+
def show_img_by_dict(imgdicts):
|
| 10 |
+
for keyname in imgdicts.keys():
|
| 11 |
+
cv2.imshow(keyname, imgdicts[keyname])
|
| 12 |
+
cv2.waitKey(0)
|
| 13 |
+
|
| 14 |
+
# 计算文本rgb均值
|
| 15 |
+
def letter_calculator(img, mask, bground_rgb, show_process=False):
|
| 16 |
+
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 17 |
+
# rgb to grey
|
| 18 |
+
aver_bground_rgb = 0.299 * bground_rgb[0] + 0.587 * bground_rgb[1] + 0.114 * bground_rgb[2]
|
| 19 |
+
thresh_low = 127
|
| 20 |
+
retval, threshed = cv2.threshold(gray, 127, 255, cv2.THRESH_OTSU)
|
| 21 |
+
|
| 22 |
+
if aver_bground_rgb < thresh_low:
|
| 23 |
+
threshed = 255 - threshed
|
| 24 |
+
threshed = 255 - threshed
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
threshed = cv2.bitwise_and(threshed, mask)
|
| 28 |
+
le_region = np.where(threshed==255)
|
| 29 |
+
mat_region = img[le_region]
|
| 30 |
+
|
| 31 |
+
if mat_region.shape[0] == 0:
|
| 32 |
+
# retval, threshed = cv2.threshold(gray, 20, 255, cv2.THRESH_BINARY)
|
| 33 |
+
# cv2.imshow("xxx", threshed)
|
| 34 |
+
# cv2.imshow("2xxx", img)
|
| 35 |
+
# cv2.waitKey(0)
|
| 36 |
+
return [-1, -1, -1], threshed
|
| 37 |
+
|
| 38 |
+
letter_rgb = np.mean(mat_region, axis=0).astype(int).tolist()
|
| 39 |
+
|
| 40 |
+
if show_process:
|
| 41 |
+
cv2.imshow("thresh", threshed)
|
| 42 |
+
# ocr_protest(threshed)
|
| 43 |
+
imgcp = np.copy(img)
|
| 44 |
+
imgcp *= 0
|
| 45 |
+
imgcp += 127
|
| 46 |
+
imgcp[le_region] = letter_rgb
|
| 47 |
+
cv2.imshow("letter_img", imgcp)
|
| 48 |
+
# cv2.waitKey(0)
|
| 49 |
+
|
| 50 |
+
return letter_rgb, threshed
|
| 51 |
+
|
| 52 |
+
# 预处理让文本颜色提取准确点
|
| 53 |
+
def usm(src):
|
| 54 |
+
blur_img = cv2.GaussianBlur(src, (0, 0), 5)
|
| 55 |
+
usm = cv2.addWeighted(src, 1.5, blur_img, -0.5, 0)
|
| 56 |
+
h, w = src.shape[:2]
|
| 57 |
+
result = np.zeros([h, w*2, 3], dtype=src.dtype)
|
| 58 |
+
result[0:h,0:w,:] = src
|
| 59 |
+
result[0:h,w:2*w,:] = usm
|
| 60 |
+
return usm
|
| 61 |
+
|
| 62 |
+
# 计算文本rgb均值方法2,可能用中位数代替均值会好点
|
| 63 |
+
def textrgb_calculator(img, text_mask, show_process=False):
|
| 64 |
+
text_mask = cv2.erode(text_mask, (3, 3), iterations=1)
|
| 65 |
+
usm_img = usm(img)
|
| 66 |
+
overall_meanrgb = np.mean(usm_img[np.where(text_mask==255)], axis=0)
|
| 67 |
+
if show_process:
|
| 68 |
+
colored_text_board = np.zeros((img.shape[0], img.shape[1], 3), dtype=np.uint8) + 127
|
| 69 |
+
colored_text_board[np.where(text_mask==255)] = overall_meanrgb
|
| 70 |
+
cv2.imshow("usm", usm_img)
|
| 71 |
+
cv2.imshow("textcolor", colored_text_board)
|
| 72 |
+
return overall_meanrgb.astype(np.uint8)
|
| 73 |
+
|
| 74 |
+
# 计算背景rgb均值和标准差
|
| 75 |
+
def bground_calculator(buble_img, back_ground_mask, dilate=True):
|
| 76 |
+
kernel = np.ones((3,3),np.uint8)
|
| 77 |
+
if dilate:
|
| 78 |
+
back_ground_mask = cv2.dilate(back_ground_mask, kernel, iterations = 1)
|
| 79 |
+
bground_region = np.where(back_ground_mask==0)
|
| 80 |
+
sd = -1
|
| 81 |
+
if len(bground_region[0]) != 0:
|
| 82 |
+
pix_array = buble_img[bground_region]
|
| 83 |
+
bground_aver = np.mean(pix_array, axis=0).astype(int)
|
| 84 |
+
pix_array - bground_aver
|
| 85 |
+
gray = cv2.cvtColor(buble_img, cv2.COLOR_RGB2GRAY)
|
| 86 |
+
gray_pixarray = gray[bground_region]
|
| 87 |
+
gray_aver = np.mean(gray_pixarray)
|
| 88 |
+
gray_pixarray = gray_pixarray - gray_aver
|
| 89 |
+
gray_pixarray = np.power(gray_pixarray, 2)
|
| 90 |
+
# gray_pixarray = np.sqrt(gray_pixarray)
|
| 91 |
+
sd = np.mean(gray_pixarray)
|
| 92 |
+
else: bground_aver = np.array([-1, -1, -1])
|
| 93 |
+
|
| 94 |
+
return bground_aver, bground_region, sd
|
| 95 |
+
|
| 96 |
+
# 输入:文本块roi,分割出文本mask,根据mask计算文本bgr均值和标准差,决定纯色覆盖/inpaint修复
|
| 97 |
+
def canny_flood(img, show_process=False, inpaint_sdthresh=10, **kwargs):
|
| 98 |
+
# cv2.setNumThreads(4)
|
| 99 |
+
WHITE = (255, 255, 255)
|
| 100 |
+
BLACK = (0, 0, 0)
|
| 101 |
+
kernel = np.ones((3,3),np.uint8)
|
| 102 |
+
orih, oriw = img.shape[0], img.shape[1]
|
| 103 |
+
scaleR = 1
|
| 104 |
+
if orih > 300 and oriw > 300:
|
| 105 |
+
scaleR = 0.6
|
| 106 |
+
elif orih < 120 or oriw < 120:
|
| 107 |
+
scaleR = 1.4
|
| 108 |
+
|
| 109 |
+
if scaleR != 1:
|
| 110 |
+
h, w = img.shape[0], img.shape[1]
|
| 111 |
+
orimg = np.copy(img)
|
| 112 |
+
img = cv2.resize(img, (int(w*scaleR), int(h*scaleR)), interpolation=cv2.INTER_AREA)
|
| 113 |
+
h, w = img.shape[0], img.shape[1]
|
| 114 |
+
img_area = h * w
|
| 115 |
+
|
| 116 |
+
cpimg = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
|
| 117 |
+
detected_edges = cv2.Canny(cpimg, 70, 140, L2gradient=True, apertureSize=3)
|
| 118 |
+
cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
|
| 119 |
+
|
| 120 |
+
cons, hiers = cv2.findContours(detected_edges, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
| 121 |
+
|
| 122 |
+
cv2.rectangle(detected_edges, (0, 0), (w-1, h-1), BLACK, 1, cv2.LINE_8)
|
| 123 |
+
|
| 124 |
+
ballon_mask, outer_index = np.zeros((h, w), np.uint8), -1
|
| 125 |
+
|
| 126 |
+
min_retval = np.inf
|
| 127 |
+
mask = np.zeros((h, w), np.uint8)
|
| 128 |
+
difres = 10
|
| 129 |
+
seedpnt = (int(w/2), int(h/2))
|
| 130 |
+
for ii in range(len(cons)):
|
| 131 |
+
rect = cv2.boundingRect(cons[ii])
|
| 132 |
+
if rect[2]*rect[3] < img_area*0.4:
|
| 133 |
+
continue
|
| 134 |
+
|
| 135 |
+
mask = cv2.drawContours(mask, cons, ii, (255), 2)
|
| 136 |
+
cpmask = np.copy(mask)
|
| 137 |
+
cv2.rectangle(mask, (0, 0), (w-1, h-1), WHITE, 1, cv2.LINE_8)
|
| 138 |
+
retval, _, _, rect = cv2.floodFill(cpmask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 139 |
+
|
| 140 |
+
if retval <= img_area * 0.3:
|
| 141 |
+
mask = cv2.drawContours(mask, cons, ii, (0), 2)
|
| 142 |
+
if retval < min_retval and retval > img_area * 0.3:
|
| 143 |
+
min_retval = retval
|
| 144 |
+
ballon_mask = cpmask
|
| 145 |
+
|
| 146 |
+
ballon_mask = 127 - ballon_mask
|
| 147 |
+
ballon_mask = cv2.dilate(ballon_mask, kernel,iterations = 1)
|
| 148 |
+
outer_area, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(30), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 149 |
+
ballon_mask = 30 - ballon_mask
|
| 150 |
+
retval, ballon_mask = cv2.threshold(ballon_mask, 1, 255, cv2.THRESH_BINARY)
|
| 151 |
+
ballon_mask = cv2.bitwise_not(ballon_mask, ballon_mask)
|
| 152 |
+
|
| 153 |
+
detected_edges = cv2.dilate(detected_edges, kernel, iterations = 1)
|
| 154 |
+
for ii in range(2):
|
| 155 |
+
detected_edges = cv2.bitwise_and(detected_edges, ballon_mask)
|
| 156 |
+
mask = np.copy(detected_edges)
|
| 157 |
+
bgarea1, _, _, rect = cv2.floodFill(mask, mask=None, seedPoint=(0, 0), flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 158 |
+
bgarea2, _, _, rect = cv2.floodFill(mask, mask=None, seedPoint=(detected_edges.shape[1]-1, detected_edges.shape[0]-1), flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 159 |
+
txt_area = min(img_area - bgarea1, img_area - bgarea2)
|
| 160 |
+
ratio_ob = txt_area / outer_area
|
| 161 |
+
ballon_mask = cv2.erode(ballon_mask, kernel,iterations = 1)
|
| 162 |
+
if ratio_ob < 0.85:
|
| 163 |
+
break
|
| 164 |
+
|
| 165 |
+
mask = 127 - mask
|
| 166 |
+
retval, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
|
| 167 |
+
if scaleR != 1:
|
| 168 |
+
img = orimg
|
| 169 |
+
ballon_mask = cv2.resize(ballon_mask, (oriw, orih))
|
| 170 |
+
mask = cv2.resize(mask, (oriw, orih))
|
| 171 |
+
|
| 172 |
+
bg_mask = cv2.bitwise_or(mask, 255-ballon_mask)
|
| 173 |
+
mask = cv2.bitwise_and(mask, ballon_mask)
|
| 174 |
+
|
| 175 |
+
bground_aver, bground_region, sd = bground_calculator(img, bg_mask)
|
| 176 |
+
inner_rect = None
|
| 177 |
+
threshed = np.zeros((img.shape[0], img.shape[1]), np.uint8)
|
| 178 |
+
|
| 179 |
+
if bground_aver[0] != -1:
|
| 180 |
+
letter_aver, threshed = letter_calculator(img, mask, bground_aver, show_process=show_process)
|
| 181 |
+
if letter_aver[0] != -1:
|
| 182 |
+
mask = cv2.dilate(threshed, kernel, iterations=1)
|
| 183 |
+
inner_rect = cv2.boundingRect(cv2.findNonZero(mask))
|
| 184 |
+
else: letter_aver = [0, 0, 0]
|
| 185 |
+
|
| 186 |
+
if sd != -1 and sd < inpaint_sdthresh:
|
| 187 |
+
need_inpaint = False
|
| 188 |
+
else:
|
| 189 |
+
need_inpaint = True
|
| 190 |
+
if show_process:
|
| 191 |
+
print(f"\nneed_inpaint: {need_inpaint}, sd: {sd}, {type(inner_rect)}")
|
| 192 |
+
show_img_by_dict({"outermask": ballon_mask, "detect": detected_edges, "mask": mask})
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
if isinstance(inner_rect, tuple):
|
| 196 |
+
inner_rect = [ii for ii in inner_rect]
|
| 197 |
+
if inner_rect is None:
|
| 198 |
+
inner_rect = [-1, -1, -1, -1]
|
| 199 |
+
else:
|
| 200 |
+
inner_rect.append(-1)
|
| 201 |
+
|
| 202 |
+
bground_aver = bground_aver.astype(np.uint8)
|
| 203 |
+
bub_dict = {"rgb": letter_aver,
|
| 204 |
+
"bground_rgb": bground_aver,
|
| 205 |
+
"inner_rect": inner_rect,
|
| 206 |
+
"need_inpaint": need_inpaint}
|
| 207 |
+
return mask, ballon_mask, bub_dict
|
| 208 |
+
|
| 209 |
+
# 输入:文本块roi,分割出文本mask,根据mask计算文本bgr均值和标准差,决定纯色覆盖/inpaint修复
|
| 210 |
+
def connected_canny_flood(img, show_process=False, inpaint_sdthresh=10, apply_strokewidth_check=0, **kwargs):
|
| 211 |
+
|
| 212 |
+
# 寻找最可能是气泡的外轮廓mask
|
| 213 |
+
def find_outermask(img):
|
| 214 |
+
connectivity = 4
|
| 215 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity, cv2.CV_16U)
|
| 216 |
+
drawtext = np.zeros((img.shape[0], img.shape[1]), np.uint8)
|
| 217 |
+
|
| 218 |
+
max_ind = np.argmax(stats[:, 4])
|
| 219 |
+
maxbbox_area, sec_ind = -1, -1
|
| 220 |
+
for ind, stat in enumerate(stats):
|
| 221 |
+
if ind != max_ind:
|
| 222 |
+
bbarea = stat[2] * stat[3]
|
| 223 |
+
if bbarea > maxbbox_area:
|
| 224 |
+
maxbbox_area = bbarea
|
| 225 |
+
sec_ind = ind
|
| 226 |
+
drawtext[np.where(labels==max_ind)] = 255
|
| 227 |
+
|
| 228 |
+
cv2.rectangle(drawtext, (0, 0), (img.shape[1]-1, img.shape[0]-1), (0, 0, 0), 1, cv2.LINE_8)
|
| 229 |
+
cons, hiers = cv2.findContours(drawtext, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
| 230 |
+
img_area = img.shape[0] * img.shape[1]
|
| 231 |
+
|
| 232 |
+
rects = np.array([cv2.boundingRect(cnt) for cnt in cons])
|
| 233 |
+
rect_area = np.array([rect[2] * rect[3] for rect in rects])
|
| 234 |
+
quali_ind = np.where(rect_area > img_area * 0.3)[0]
|
| 235 |
+
ballon_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
|
| 236 |
+
for ind in quali_ind:
|
| 237 |
+
ballon_mask = cv2.drawContours(ballon_mask, cons, ind, (255), 2)
|
| 238 |
+
|
| 239 |
+
seedpnt = (int(ballon_mask.shape[1]/2), int(ballon_mask.shape[0]/2))
|
| 240 |
+
difres = 10
|
| 241 |
+
retval, _, _, rect = cv2.floodFill(ballon_mask, mask=None, seedPoint=seedpnt, flags=4, newVal=(127), loDiff=(difres, difres, difres), upDiff=(difres, difres, difres))
|
| 242 |
+
ballon_mask = 255 - cv2.threshold(ballon_mask - 127, 1, 255, cv2.THRESH_BINARY)[1]
|
| 243 |
+
return num_labels, labels, stats, centroids, ballon_mask
|
| 244 |
+
|
| 245 |
+
# BGR直接转灰度图可能导致文本区域和背景难以区分,比如测试样例中的黑底红字
|
| 246 |
+
# 但是总有一个通道文本和背景容易区分
|
| 247 |
+
# 返回最容易区分的那个通道
|
| 248 |
+
def ccctest(img, crop_r=0.1):
|
| 249 |
+
# img = usm(img)
|
| 250 |
+
maxh = 100
|
| 251 |
+
if img.shape[0] > maxh:
|
| 252 |
+
scaleR = maxh / img.shape[0]
|
| 253 |
+
im = cv2.resize(img, (int(img.shape[1]*scaleR), int(img.shape[0]*scaleR)), interpolation=cv2.INTER_AREA)
|
| 254 |
+
else:
|
| 255 |
+
im = img
|
| 256 |
+
|
| 257 |
+
textlabel_counter = 0
|
| 258 |
+
reverse = False
|
| 259 |
+
c_ind = 0
|
| 260 |
+
|
| 261 |
+
num_labels, labels, stats, centroids, pseduo_outermask = find_outermask(cv2.threshold(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY), 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)[1])
|
| 262 |
+
grayim = np.expand_dims(np.array(cv2.cvtColor(im, cv2.COLOR_RGB2GRAY)), axis=2)
|
| 263 |
+
im = np.append(im, grayim, axis=2)
|
| 264 |
+
outer_cords = np.where(pseduo_outermask==255)
|
| 265 |
+
for bgr_ind in range(4):
|
| 266 |
+
channel = im[:, :, bgr_ind]
|
| 267 |
+
ret, thresh = cv2.threshold(channel, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
|
| 268 |
+
|
| 269 |
+
tmp_reverse = False
|
| 270 |
+
|
| 271 |
+
if np.mean(thresh[outer_cords]) > 160:
|
| 272 |
+
thresh = 255 - thresh
|
| 273 |
+
tmp_reverse = True
|
| 274 |
+
|
| 275 |
+
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(thresh, 4, cv2.CV_16U)
|
| 276 |
+
# draw_connected_labels(num_labels, labels, stats, centroids)
|
| 277 |
+
# cv2.waitKey(0)
|
| 278 |
+
max_ind = np.argmax(stats[:, 4])
|
| 279 |
+
maxr, minr = 0.5, 0.001
|
| 280 |
+
maxw, maxh = stats[max_ind][2] * maxr, stats[max_ind][3] * maxr
|
| 281 |
+
minarea = im.shape[0] * im.shape[1] * minr
|
| 282 |
+
|
| 283 |
+
tmp_counter = 0
|
| 284 |
+
for stat in stats:
|
| 285 |
+
bboxarea = stat[2] * stat[3]
|
| 286 |
+
if stat[2] < maxw and stat[3] < maxh and bboxarea > minarea:
|
| 287 |
+
tmp_counter += 1
|
| 288 |
+
if tmp_counter > textlabel_counter:
|
| 289 |
+
textlabel_counter = tmp_counter
|
| 290 |
+
c_ind = bgr_ind
|
| 291 |
+
reverse = tmp_reverse
|
| 292 |
+
return c_ind, reverse
|
| 293 |
+
|
| 294 |
+
channel_index, reverse = ccctest(img)
|
| 295 |
+
chanel = img[:, :, channel_index] if channel_index < 3 else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 296 |
+
ret, thresh = cv2.threshold(chanel, 1, 255, cv2.THRESH_OTSU+cv2.THRESH_BINARY)
|
| 297 |
+
|
| 298 |
+
# reverse to get white text on black bg
|
| 299 |
+
if reverse:
|
| 300 |
+
thresh = 255 - thresh
|
| 301 |
+
num_labels, labels, stats, centroids, ballon_mask = find_outermask(thresh)
|
| 302 |
+
img_area = img.shape[0] * img.shape[1]
|
| 303 |
+
text_mask = np.zeros((img.shape[0], img.shape[1]), np.uint8)
|
| 304 |
+
max_ind = np.argmax(stats[:, 4])
|
| 305 |
+
for lab in (range(num_labels)):
|
| 306 |
+
stat = stats[lab]
|
| 307 |
+
if lab != max_ind and stat[4] < img_area * 0.4:
|
| 308 |
+
labcord = np.where(labels==lab)
|
| 309 |
+
text_mask[labcord] = 255
|
| 310 |
+
|
| 311 |
+
text_mask = cv2.bitwise_and(text_mask, ballon_mask)
|
| 312 |
+
if apply_strokewidth_check > 0:
|
| 313 |
+
text_mask = strokewidth_check(text_mask, labels, num_labels, stats, debug_type=show_process-1)
|
| 314 |
+
|
| 315 |
+
text_color = textrgb_calculator(img, text_mask, show_process=show_process)
|
| 316 |
+
inner_rect = cv2.boundingRect(cv2.findNonZero(cv2.dilate(text_mask, (3, 3), iterations=1)))
|
| 317 |
+
inner_rect = [ii for ii in inner_rect]
|
| 318 |
+
inner_rect.append(-1)
|
| 319 |
+
|
| 320 |
+
bg_mask = cv2.bitwise_or(text_mask, 255-ballon_mask)
|
| 321 |
+
|
| 322 |
+
bground_aver, bground_region, sd = bground_calculator(img, bg_mask)
|
| 323 |
+
|
| 324 |
+
mask = cv2.GaussianBlur(text_mask,(3,3),cv2.BORDER_DEFAULT)
|
| 325 |
+
_, mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
|
| 326 |
+
if sd != -1 and sd < inpaint_sdthresh:
|
| 327 |
+
need_inpaint = False
|
| 328 |
+
else:
|
| 329 |
+
need_inpaint = True
|
| 330 |
+
|
| 331 |
+
if show_process:
|
| 332 |
+
print(f"\nuse inpaint: {need_inpaint}, sd: {sd}, {type(inner_rect)}")
|
| 333 |
+
draw_connected_labels(num_labels, labels, stats, centroids)
|
| 334 |
+
show_img_by_dict({"thresh": thresh, "ori": img, "outer": ballon_mask, "text": text_mask, "bgmask": bg_mask})
|
| 335 |
+
|
| 336 |
+
bground_aver = bground_aver.astype(np.uint8)
|
| 337 |
+
bub_dict = {"rgb": text_color,
|
| 338 |
+
"bground_rgb": bground_aver,
|
| 339 |
+
"inner_rect": inner_rect,
|
| 340 |
+
"need_inpaint": need_inpaint}
|
| 341 |
+
return mask, ballon_mask, bub_dict
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def existing_mask(img, mask: np.ndarray):
|
| 345 |
+
bub_dict = {"rgb": [0, 0, 0],"bground_rgb": [255, 255, 255],"need_inpaint": True}
|
| 346 |
+
return mask, mask, bub_dict
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def extract_ballon_mask(img: np.ndarray, mask: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
| 350 |
+
'''
|
| 351 |
+
Given original img and text mask (cropped)
|
| 352 |
+
return ballon mask & non text mask
|
| 353 |
+
'''
|
| 354 |
+
img = cv2.GaussianBlur(img,(3,3),cv2.BORDER_DEFAULT)
|
| 355 |
+
h, w = img.shape[:2]
|
| 356 |
+
text_sum = np.sum(mask)
|
| 357 |
+
cannyed = cv2.Canny(img, 70, 140, L2gradient=True, apertureSize=3)
|
| 358 |
+
e_size = 1
|
| 359 |
+
element = cv2.getStructuringElement(cv2.MORPH_RECT, (2 * e_size + 1, 2 * e_size + 1),(e_size, e_size))
|
| 360 |
+
cannyed = cv2.dilate(cannyed, element, iterations=1)
|
| 361 |
+
br = cv2.boundingRect(cv2.findNonZero(mask))
|
| 362 |
+
br_xyxy = [br[0], br[1], br[0] + br[2], br[1] + br[3]]
|
| 363 |
+
|
| 364 |
+
# draw the bounding rect in case there is no closed ballon
|
| 365 |
+
cv2.rectangle(cannyed, (0, 0), (w-1, h-1), (255, 255, 255), 1, cv2.LINE_8)
|
| 366 |
+
cannyed = cv2.bitwise_and(cannyed, 255 - mask)
|
| 367 |
+
|
| 368 |
+
cons, _ = cv2.findContours(cannyed, cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE)
|
| 369 |
+
min_ballon_area = w * h
|
| 370 |
+
ballon_mask = None
|
| 371 |
+
non_text_mask = None
|
| 372 |
+
# minimum contour which covers all text mask must be the ballon
|
| 373 |
+
for ii, con in enumerate(cons):
|
| 374 |
+
br_c = cv2.boundingRect(con)
|
| 375 |
+
br_c = [br_c[0], br_c[1], br_c[0] + br_c[2], br_c[1] + br_c[3]]
|
| 376 |
+
if br_c[0] > br_xyxy[0] or br_c[1] > br_xyxy[1] or br_c[2] < br_xyxy[2] or br_c[3] < br_xyxy[3]:
|
| 377 |
+
continue
|
| 378 |
+
tmp = np.zeros_like(cannyed)
|
| 379 |
+
cv2.drawContours(tmp, cons, ii, (255, 255, 255), -1, cv2.LINE_8)
|
| 380 |
+
if cv2.bitwise_and(tmp, mask).sum() >= text_sum:
|
| 381 |
+
con_area = cv2.contourArea(con)
|
| 382 |
+
if con_area < min_ballon_area:
|
| 383 |
+
min_ballon_area = con_area
|
| 384 |
+
ballon_mask = tmp
|
| 385 |
+
if ballon_mask is not None:
|
| 386 |
+
non_text_mask = cv2.bitwise_and(ballon_mask, 255 - mask)
|
| 387 |
+
# cv2.imshow('ballon', ballon_mask)
|
| 388 |
+
# cv2.imshow('non_text', non_text_mask)
|
| 389 |
+
# cv2.imshow('im', img)
|
| 390 |
+
# cv2.imshow('msk', mask)
|
| 391 |
+
# cv2.imshow('canny', cannyed)
|
| 392 |
+
# cv2.waitKey(0)
|
| 393 |
+
|
| 394 |
+
return ballon_mask, non_text_mask
|
utils/textlines_merge.py
ADDED
|
@@ -0,0 +1,568 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import functools
|
| 3 |
+
from typing import Tuple, List, ClassVar, Union, Any, Dict, Set
|
| 4 |
+
from collections import Counter
|
| 5 |
+
try:
|
| 6 |
+
functools.cached_property
|
| 7 |
+
except AttributeError: # Supports Python versions below 3.8
|
| 8 |
+
from backports.cached_property import cached_property
|
| 9 |
+
functools.cached_property = cached_property
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
from shapely.geometry import Polygon, MultiPoint
|
| 13 |
+
import cv2
|
| 14 |
+
import networkx as nx
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BBox(object):
|
| 18 |
+
def __init__(self, x: int, y: int, w: int, h: int, text: str, prob: float, fg_r: int = 0, fg_g: int = 0, fg_b: int = 0, bg_r: int = 0, bg_g: int = 0, bg_b: int = 0):
|
| 19 |
+
self.x = x
|
| 20 |
+
self.y = y
|
| 21 |
+
self.w = w
|
| 22 |
+
self.h = h
|
| 23 |
+
self.text = text
|
| 24 |
+
self.prob = prob
|
| 25 |
+
self.fg_r = fg_r
|
| 26 |
+
self.fg_g = fg_g
|
| 27 |
+
self.fg_b = fg_b
|
| 28 |
+
self.bg_r = bg_r
|
| 29 |
+
self.bg_g = bg_g
|
| 30 |
+
self.bg_b = bg_b
|
| 31 |
+
|
| 32 |
+
def width(self):
|
| 33 |
+
return self.w
|
| 34 |
+
|
| 35 |
+
def height(self):
|
| 36 |
+
return self.h
|
| 37 |
+
|
| 38 |
+
def to_points(self):
|
| 39 |
+
tl, tr, br, bl = np.array([self.x, self.y]), np.array([self.x + self.w, self.y]), np.array([self.x + self.w, self.y+ self.h]), np.array([self.x, self.y + self.h])
|
| 40 |
+
return tl, tr, br, bl
|
| 41 |
+
|
| 42 |
+
@property
|
| 43 |
+
def xywh(self):
|
| 44 |
+
return np.array([self.x, self.y, self.w, self.h], dtype=np.int32)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Quadrilateral(object):
|
| 48 |
+
"""
|
| 49 |
+
Helper for storing textlines that contains various helper functions.
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, pts: np.ndarray, text: str, prob: float, fg_r: int = 0, fg_g: int = 0, fg_b: int = 0, bg_r: int = 0, bg_g: int = 0, bg_b: int = 0):
|
| 52 |
+
self.pts, is_vertical = sort_pnts(pts)
|
| 53 |
+
if is_vertical:
|
| 54 |
+
self.direction = 'v'
|
| 55 |
+
else:
|
| 56 |
+
self.direction = 'h'
|
| 57 |
+
self.text = text
|
| 58 |
+
self.prob = prob
|
| 59 |
+
self.fg_r = fg_r
|
| 60 |
+
self.fg_g = fg_g
|
| 61 |
+
self.fg_b = fg_b
|
| 62 |
+
self.bg_r = bg_r
|
| 63 |
+
self.bg_g = bg_g
|
| 64 |
+
self.bg_b = bg_b
|
| 65 |
+
self.assigned_direction: str = None
|
| 66 |
+
self.textlines: List[Quadrilateral] = []
|
| 67 |
+
|
| 68 |
+
@functools.cached_property
|
| 69 |
+
def structure(self) -> List[np.ndarray]:
|
| 70 |
+
p1 = ((self.pts[0] + self.pts[1]) / 2).astype(int)
|
| 71 |
+
p2 = ((self.pts[2] + self.pts[3]) / 2).astype(int)
|
| 72 |
+
p3 = ((self.pts[1] + self.pts[2]) / 2).astype(int)
|
| 73 |
+
p4 = ((self.pts[3] + self.pts[0]) / 2).astype(int)
|
| 74 |
+
return [p1, p2, p3, p4]
|
| 75 |
+
|
| 76 |
+
@functools.cached_property
|
| 77 |
+
def valid(self) -> bool:
|
| 78 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 79 |
+
v1 = l1b - l1a
|
| 80 |
+
v2 = l2b - l2a
|
| 81 |
+
unit_vector_1 = v1 / np.linalg.norm(v1)
|
| 82 |
+
unit_vector_2 = v2 / np.linalg.norm(v2)
|
| 83 |
+
dot_product = np.dot(unit_vector_1, unit_vector_2)
|
| 84 |
+
angle = np.arccos(dot_product) * 180 / np.pi
|
| 85 |
+
return abs(angle - 90) < 10
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def fg_colors(self):
|
| 89 |
+
return np.array([self.fg_r, self.fg_g, self.fg_b])
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def bg_colors(self):
|
| 93 |
+
return np.array([self.bg_r, self.bg_g, self.bg_b])
|
| 94 |
+
|
| 95 |
+
@functools.cached_property
|
| 96 |
+
def aspect_ratio(self) -> float:
|
| 97 |
+
"""hor/ver"""
|
| 98 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 99 |
+
v1 = l1b - l1a
|
| 100 |
+
v2 = l2b - l2a
|
| 101 |
+
return np.linalg.norm(v2) / np.linalg.norm(v1)
|
| 102 |
+
|
| 103 |
+
@functools.cached_property
|
| 104 |
+
def font_size(self) -> float:
|
| 105 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 106 |
+
v1 = l1b - l1a
|
| 107 |
+
v2 = l2b - l2a
|
| 108 |
+
return min(np.linalg.norm(v2), np.linalg.norm(v1))
|
| 109 |
+
|
| 110 |
+
def width(self) -> int:
|
| 111 |
+
return self.aabb.w
|
| 112 |
+
|
| 113 |
+
def height(self) -> int:
|
| 114 |
+
return self.aabb.h
|
| 115 |
+
|
| 116 |
+
@functools.cached_property
|
| 117 |
+
def xyxy(self):
|
| 118 |
+
return self.aabb.x, self.aabb.y, self.aabb.x + self.aabb.w, self.aabb.y + self.aabb.h
|
| 119 |
+
|
| 120 |
+
def clip(self, width, height):
|
| 121 |
+
self.pts[:, 0] = np.clip(np.round(self.pts[:, 0]), 0, width)
|
| 122 |
+
self.pts[:, 1] = np.clip(np.round(self.pts[:, 1]), 0, height)
|
| 123 |
+
|
| 124 |
+
# @functools.cached_property
|
| 125 |
+
# def points(self):
|
| 126 |
+
# ans = [a.astype(np.float32) for a in self.structure]
|
| 127 |
+
# return [Point(a[0], a[1]) for a in ans]
|
| 128 |
+
|
| 129 |
+
@functools.cached_property
|
| 130 |
+
def aabb(self) -> BBox:
|
| 131 |
+
kq = self.pts
|
| 132 |
+
max_coord = np.max(kq, axis = 0)
|
| 133 |
+
min_coord = np.min(kq, axis = 0)
|
| 134 |
+
return BBox(min_coord[0], min_coord[1], max_coord[0] - min_coord[0], max_coord[1] - min_coord[1], self.text, self.prob, self.fg_r, self.fg_g, self.fg_b, self.bg_r, self.bg_g, self.bg_b)
|
| 135 |
+
|
| 136 |
+
def get_transformed_region(self, img, direction, textheight) -> np.ndarray:
|
| 137 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 138 |
+
v_vec = l1b - l1a
|
| 139 |
+
h_vec = l2b - l2a
|
| 140 |
+
ratio = np.linalg.norm(v_vec) / np.linalg.norm(h_vec)
|
| 141 |
+
|
| 142 |
+
src_pts = self.pts.astype(np.int64).copy()
|
| 143 |
+
im_h, im_w = img.shape[:2]
|
| 144 |
+
|
| 145 |
+
x1, y1, x2, y2 = src_pts[:, 0].min(), src_pts[:, 1].min(), src_pts[:, 0].max(), src_pts[:, 1].max()
|
| 146 |
+
x1 = np.clip(x1, 0, im_w)
|
| 147 |
+
y1 = np.clip(y1, 0, im_h)
|
| 148 |
+
x2 = np.clip(x2, 0, im_w)
|
| 149 |
+
y2 = np.clip(y2, 0, im_h)
|
| 150 |
+
# cv2.warpPerspective could overflow if image size is too large, better crop it here
|
| 151 |
+
img_croped = img[y1: y2, x1: x2]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
src_pts[:, 0] -= x1
|
| 155 |
+
src_pts[:, 1] -= y1
|
| 156 |
+
|
| 157 |
+
self.assigned_direction = direction
|
| 158 |
+
if direction == 'h':
|
| 159 |
+
h = max(int(textheight), 2)
|
| 160 |
+
w = max(int(round(textheight / ratio)), 2)
|
| 161 |
+
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
|
| 162 |
+
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
| 163 |
+
region = cv2.warpPerspective(img_croped, M, (w, h))
|
| 164 |
+
return region
|
| 165 |
+
elif direction == 'v':
|
| 166 |
+
w = max(int(textheight), 2)
|
| 167 |
+
h = max(int(round(textheight * ratio)), 2)
|
| 168 |
+
dst_pts = np.array([[0, 0], [w - 1, 0], [w - 1, h - 1], [0, h - 1]]).astype(np.float32)
|
| 169 |
+
M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
|
| 170 |
+
region = cv2.warpPerspective(img_croped, M, (w, h))
|
| 171 |
+
region = cv2.rotate(region, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
| 172 |
+
return region
|
| 173 |
+
|
| 174 |
+
@functools.cached_property
|
| 175 |
+
def is_axis_aligned(self) -> bool:
|
| 176 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 177 |
+
v1 = l1b - l1a
|
| 178 |
+
v2 = l2b - l2a
|
| 179 |
+
e1 = np.array([0, 1])
|
| 180 |
+
e2 = np.array([1, 0])
|
| 181 |
+
unit_vector_1 = v1 / np.linalg.norm(v1)
|
| 182 |
+
unit_vector_2 = v2 / np.linalg.norm(v2)
|
| 183 |
+
if abs(np.dot(unit_vector_1, e1)) < 1e-2 or abs(np.dot(unit_vector_1, e2)) < 1e-2:
|
| 184 |
+
return True
|
| 185 |
+
return False
|
| 186 |
+
|
| 187 |
+
@functools.cached_property
|
| 188 |
+
def is_approximate_axis_aligned(self) -> bool:
|
| 189 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 190 |
+
v1 = l1b - l1a
|
| 191 |
+
v2 = l2b - l2a
|
| 192 |
+
e1 = np.array([0, 1])
|
| 193 |
+
e2 = np.array([1, 0])
|
| 194 |
+
unit_vector_1 = v1 / np.linalg.norm(v1)
|
| 195 |
+
unit_vector_2 = v2 / np.linalg.norm(v2)
|
| 196 |
+
if abs(np.dot(unit_vector_1, e1)) < 0.05 or abs(np.dot(unit_vector_1, e2)) < 0.05 or abs(np.dot(unit_vector_2, e1)) < 0.05 or abs(np.dot(unit_vector_2, e2)) < 0.05:
|
| 197 |
+
return True
|
| 198 |
+
return False
|
| 199 |
+
|
| 200 |
+
@functools.cached_property
|
| 201 |
+
def cosangle(self) -> float:
|
| 202 |
+
[l1a, l1b, l2a, l2b] = [a.astype(np.float32) for a in self.structure]
|
| 203 |
+
v1 = l1b - l1a
|
| 204 |
+
e2 = np.array([1, 0])
|
| 205 |
+
unit_vector_1 = v1 / np.linalg.norm(v1)
|
| 206 |
+
return np.dot(unit_vector_1, e2)
|
| 207 |
+
|
| 208 |
+
@functools.cached_property
|
| 209 |
+
def angle(self) -> float:
|
| 210 |
+
return np.fmod(np.arccos(self.cosangle) + np.pi, np.pi)
|
| 211 |
+
|
| 212 |
+
@functools.cached_property
|
| 213 |
+
def centroid(self) -> np.ndarray:
|
| 214 |
+
return np.average(self.pts, axis = 0)
|
| 215 |
+
|
| 216 |
+
def distance_to_point(self, p: np.ndarray) -> float:
|
| 217 |
+
d = 1.0e20
|
| 218 |
+
for i in range(4):
|
| 219 |
+
d = min(d, distance_point_point(p, self.pts[i]))
|
| 220 |
+
d = min(d, distance_point_lineseg(p, self.pts[i], self.pts[(i + 1) % 4]))
|
| 221 |
+
return d
|
| 222 |
+
|
| 223 |
+
@functools.cached_property
|
| 224 |
+
def polygon(self) -> Polygon:
|
| 225 |
+
return MultiPoint([tuple(self.pts[0]), tuple(self.pts[1]), tuple(self.pts[2]), tuple(self.pts[3])]).convex_hull
|
| 226 |
+
|
| 227 |
+
@functools.cached_property
|
| 228 |
+
def area(self) -> float:
|
| 229 |
+
return self.polygon.area
|
| 230 |
+
|
| 231 |
+
def poly_distance(self, other) -> float:
|
| 232 |
+
return self.polygon.distance(other.polygon)
|
| 233 |
+
|
| 234 |
+
def distance(self, other, rho = 0.5) -> float:
|
| 235 |
+
return self.distance_impl(other, rho)# + 1000 * abs(self.angle - other.angle)
|
| 236 |
+
|
| 237 |
+
def distance_impl(self, other, rho = 0.5) -> float:
|
| 238 |
+
# assert self.assigned_direction == other.assigned_direction
|
| 239 |
+
#return gjk_distance(self.points, other.points)
|
| 240 |
+
# b1 = self.aabb
|
| 241 |
+
# b2 = b2.aabb
|
| 242 |
+
# x1, y1, w1, h1 = b1.x, b1.y, b1.w, b1.h
|
| 243 |
+
# x2, y2, w2, h2 = b2.x, b2.y, b2.w, b2.h
|
| 244 |
+
# return rect_distance(x1, y1, x1 + w1, y1 + h1, x2, y2, x2 + w2, y2 + h2)
|
| 245 |
+
pattern = ''
|
| 246 |
+
if self.assigned_direction == 'h':
|
| 247 |
+
pattern = 'h_left'
|
| 248 |
+
else:
|
| 249 |
+
pattern = 'v_top'
|
| 250 |
+
fs = max(self.font_size, other.font_size)
|
| 251 |
+
if self.assigned_direction == 'h':
|
| 252 |
+
poly1 = MultiPoint([tuple(self.pts[0]), tuple(self.pts[3]), tuple(other.pts[0]), tuple(other.pts[3])]).convex_hull
|
| 253 |
+
poly2 = MultiPoint([tuple(self.pts[2]), tuple(self.pts[1]), tuple(other.pts[2]), tuple(other.pts[1])]).convex_hull
|
| 254 |
+
poly3 = MultiPoint([
|
| 255 |
+
tuple(self.structure[0]),
|
| 256 |
+
tuple(self.structure[1]),
|
| 257 |
+
tuple(other.structure[0]),
|
| 258 |
+
tuple(other.structure[1]),
|
| 259 |
+
]).convex_hull
|
| 260 |
+
dist1 = poly1.area / fs
|
| 261 |
+
dist2 = poly2.area / fs
|
| 262 |
+
dist3 = poly3.area / fs
|
| 263 |
+
if dist1 < fs * rho:
|
| 264 |
+
pattern = 'h_left'
|
| 265 |
+
if dist2 < fs * rho and dist2 < dist1:
|
| 266 |
+
pattern = 'h_right'
|
| 267 |
+
if dist3 < fs * rho and dist3 < dist1 and dist3 < dist2:
|
| 268 |
+
pattern = 'h_middle'
|
| 269 |
+
if pattern == 'h_left':
|
| 270 |
+
return dist(self.pts[0][0], self.pts[0][1], other.pts[0][0], other.pts[0][1])
|
| 271 |
+
elif pattern == 'h_right':
|
| 272 |
+
return dist(self.pts[1][0], self.pts[1][1], other.pts[1][0], other.pts[1][1])
|
| 273 |
+
else:
|
| 274 |
+
return dist(self.structure[0][0], self.structure[0][1], other.structure[0][0], other.structure[0][1])
|
| 275 |
+
else:
|
| 276 |
+
poly1 = MultiPoint([tuple(self.pts[0]), tuple(self.pts[1]), tuple(other.pts[0]), tuple(other.pts[1])]).convex_hull
|
| 277 |
+
poly2 = MultiPoint([tuple(self.pts[2]), tuple(self.pts[3]), tuple(other.pts[2]), tuple(other.pts[3])]).convex_hull
|
| 278 |
+
dist1 = poly1.area / fs
|
| 279 |
+
dist2 = poly2.area / fs
|
| 280 |
+
if dist1 < fs * rho:
|
| 281 |
+
pattern = 'v_top'
|
| 282 |
+
if dist2 < fs * rho and dist2 < dist1:
|
| 283 |
+
pattern = 'v_bottom'
|
| 284 |
+
if pattern == 'v_top':
|
| 285 |
+
return dist(self.pts[0][0], self.pts[0][1], other.pts[0][0], other.pts[0][1])
|
| 286 |
+
else:
|
| 287 |
+
return dist(self.pts[2][0], self.pts[2][1], other.pts[2][0], other.pts[2][1])
|
| 288 |
+
|
| 289 |
+
def copy(self, new_pts: np.ndarray):
|
| 290 |
+
return Quadrilateral(new_pts, self.text, self.prob, *self.fg_colors, *self.bg_colors)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def sort_pnts(pts: np.ndarray):
|
| 294 |
+
'''
|
| 295 |
+
Direction must be provided for sorting.
|
| 296 |
+
The longer structure vector (mean of long side vectors) of input points is used to determine the direction.
|
| 297 |
+
It is reliable enough for text lines but not for blocks.
|
| 298 |
+
'''
|
| 299 |
+
|
| 300 |
+
if isinstance(pts, List):
|
| 301 |
+
pts = np.array(pts)
|
| 302 |
+
assert isinstance(pts, np.ndarray) and pts.shape == (4, 2)
|
| 303 |
+
pairwise_vec = (pts[:, None] - pts[None]).reshape((16, -1))
|
| 304 |
+
pairwise_vec_norm = np.linalg.norm(pairwise_vec, axis=1)
|
| 305 |
+
long_side_ids = np.argsort(pairwise_vec_norm)[[8, 10]]
|
| 306 |
+
long_side_vecs = pairwise_vec[long_side_ids]
|
| 307 |
+
inner_prod = (long_side_vecs[0] * long_side_vecs[1]).sum()
|
| 308 |
+
if inner_prod < 0:
|
| 309 |
+
long_side_vecs[0] = -long_side_vecs[0]
|
| 310 |
+
struc_vec = np.abs(long_side_vecs.mean(axis=0))
|
| 311 |
+
is_vertical = struc_vec[0] <= struc_vec[1]
|
| 312 |
+
|
| 313 |
+
if is_vertical:
|
| 314 |
+
pts = pts[np.argsort(pts[:, 1])]
|
| 315 |
+
pts = pts[[*np.argsort(pts[:2, 0]), *np.argsort(pts[2:, 0])[::-1] + 2]]
|
| 316 |
+
return pts, is_vertical
|
| 317 |
+
else:
|
| 318 |
+
pts = pts[np.argsort(pts[:, 0])]
|
| 319 |
+
pts_sorted = np.zeros_like(pts)
|
| 320 |
+
pts_sorted[[0, 3]] = sorted(pts[[0, 1]], key=lambda x: x[1])
|
| 321 |
+
pts_sorted[[1, 2]] = sorted(pts[[2, 3]], key=lambda x: x[1])
|
| 322 |
+
return pts_sorted, is_vertical
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
def dist(x1, y1, x2, y2):
|
| 326 |
+
return np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def distance_point_point(a: np.ndarray, b: np.ndarray) -> float:
|
| 330 |
+
return np.linalg.norm(a - b)
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
# from https://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment
|
| 334 |
+
def distance_point_lineseg(p: np.ndarray, p1: np.ndarray, p2: np.ndarray):
|
| 335 |
+
x = p[0]
|
| 336 |
+
y = p[1]
|
| 337 |
+
x1 = p1[0]
|
| 338 |
+
y1 = p1[1]
|
| 339 |
+
x2 = p2[0]
|
| 340 |
+
y2 = p2[1]
|
| 341 |
+
A = x - x1
|
| 342 |
+
B = y - y1
|
| 343 |
+
C = x2 - x1
|
| 344 |
+
D = y2 - y1
|
| 345 |
+
|
| 346 |
+
dot = A * C + B * D
|
| 347 |
+
len_sq = C * C + D * D
|
| 348 |
+
param = -1
|
| 349 |
+
if len_sq != 0:
|
| 350 |
+
param = dot / len_sq
|
| 351 |
+
|
| 352 |
+
if param < 0:
|
| 353 |
+
xx = x1
|
| 354 |
+
yy = y1
|
| 355 |
+
elif param > 1:
|
| 356 |
+
xx = x2
|
| 357 |
+
yy = y2
|
| 358 |
+
else:
|
| 359 |
+
xx = x1 + param * C
|
| 360 |
+
yy = y1 + param * D
|
| 361 |
+
|
| 362 |
+
dx = x - xx
|
| 363 |
+
dy = y - yy
|
| 364 |
+
return np.sqrt(dx * dx + dy * dy)
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def quadrilateral_can_merge_region(a: Quadrilateral, b: Quadrilateral, ratio = 1.9, discard_connection_gap = 2, char_gap_tolerance = 0.6, char_gap_tolerance2 = 1.5, font_size_ratio_tol = 1.5, aspect_ratio_tol = 2) -> bool:
|
| 368 |
+
b1 = a.aabb
|
| 369 |
+
b2 = b.aabb
|
| 370 |
+
char_size = min(a.font_size, b.font_size)
|
| 371 |
+
x1, y1, w1, h1 = b1.x, b1.y, b1.w, b1.h
|
| 372 |
+
x2, y2, w2, h2 = b2.x, b2.y, b2.w, b2.h
|
| 373 |
+
# dist = rect_distance(x1, y1, x1 + w1, y1 + h1, x2, y2, x2 + w2, y2 + h2)
|
| 374 |
+
p1 = Polygon(a.pts)
|
| 375 |
+
p2 = Polygon(b.pts)
|
| 376 |
+
dist = p1.distance(p2)
|
| 377 |
+
if dist > discard_connection_gap * char_size:
|
| 378 |
+
return False
|
| 379 |
+
if max(a.font_size, b.font_size) / char_size > font_size_ratio_tol:
|
| 380 |
+
return False
|
| 381 |
+
if a.aspect_ratio > aspect_ratio_tol and b.aspect_ratio < 1. / aspect_ratio_tol:
|
| 382 |
+
return False
|
| 383 |
+
if b.aspect_ratio > aspect_ratio_tol and a.aspect_ratio < 1. / aspect_ratio_tol:
|
| 384 |
+
return False
|
| 385 |
+
a_aa = a.is_approximate_axis_aligned
|
| 386 |
+
b_aa = b.is_approximate_axis_aligned
|
| 387 |
+
if a_aa and b_aa:
|
| 388 |
+
if dist < char_size * char_gap_tolerance:
|
| 389 |
+
if abs(x1 + w1 // 2 - (x2 + w2 // 2)) < char_gap_tolerance2:
|
| 390 |
+
return True
|
| 391 |
+
if w1 > h1 * ratio and h2 > w2 * ratio:
|
| 392 |
+
return False
|
| 393 |
+
if w2 > h2 * ratio and h1 > w1 * ratio:
|
| 394 |
+
return False
|
| 395 |
+
if w1 > h1 * ratio or w2 > h2 * ratio : # h
|
| 396 |
+
return abs(x1 - x2) < char_size * char_gap_tolerance2 or abs(x1 + w1 - (x2 + w2)) < char_size * char_gap_tolerance2
|
| 397 |
+
elif h1 > w1 * ratio or h2 > w2 * ratio : # v
|
| 398 |
+
return abs(y1 - y2) < char_size * char_gap_tolerance2 or abs(y1 + h1 - (y2 + h2)) < char_size * char_gap_tolerance2
|
| 399 |
+
return False
|
| 400 |
+
else:
|
| 401 |
+
return False
|
| 402 |
+
if True:#not a_aa and not b_aa:
|
| 403 |
+
if abs(a.angle - b.angle) < 15 * np.pi / 180:
|
| 404 |
+
fs_a = a.font_size
|
| 405 |
+
fs_b = b.font_size
|
| 406 |
+
fs = min(fs_a, fs_b)
|
| 407 |
+
if a.poly_distance(b) > fs * char_gap_tolerance2:
|
| 408 |
+
return False
|
| 409 |
+
if abs(fs_a - fs_b) / fs > 0.25:
|
| 410 |
+
return False
|
| 411 |
+
return True
|
| 412 |
+
return False
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def quadrilateral_can_merge_region_coarse(a: Quadrilateral, b: Quadrilateral, discard_connection_gap = 2, font_size_ratio_tol = 0.7) -> bool:
|
| 416 |
+
if a.assigned_direction != b.assigned_direction:
|
| 417 |
+
return False
|
| 418 |
+
if abs(a.angle - b.angle) > 15 * np.pi / 180:
|
| 419 |
+
return False
|
| 420 |
+
fs_a = a.font_size
|
| 421 |
+
fs_b = b.font_size
|
| 422 |
+
fs = min(fs_a, fs_b)
|
| 423 |
+
if abs(fs_a - fs_b) / fs > font_size_ratio_tol:
|
| 424 |
+
return False
|
| 425 |
+
fs = max(fs_a, fs_b)
|
| 426 |
+
dist = a.poly_distance(b)
|
| 427 |
+
if dist > discard_connection_gap * fs:
|
| 428 |
+
return False
|
| 429 |
+
return True
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def split_text_region(
|
| 433 |
+
bboxes: List[Quadrilateral],
|
| 434 |
+
connected_region_indices: Set[int],
|
| 435 |
+
width,
|
| 436 |
+
height,
|
| 437 |
+
gamma = 0.5,
|
| 438 |
+
sigma = 2
|
| 439 |
+
) -> List[Set[int]]:
|
| 440 |
+
|
| 441 |
+
connected_region_indices = list(connected_region_indices)
|
| 442 |
+
|
| 443 |
+
# case 1
|
| 444 |
+
if len(connected_region_indices) == 1:
|
| 445 |
+
return [set(connected_region_indices)]
|
| 446 |
+
|
| 447 |
+
# case 2
|
| 448 |
+
if len(connected_region_indices) == 2:
|
| 449 |
+
fs1 = bboxes[connected_region_indices[0]].font_size
|
| 450 |
+
fs2 = bboxes[connected_region_indices[1]].font_size
|
| 451 |
+
fs = max(fs1, fs2)
|
| 452 |
+
|
| 453 |
+
# print(bboxes[connected_region_indices[0]].pts, bboxes[connected_region_indices[1]].pts)
|
| 454 |
+
# print(fs, bboxes[connected_region_indices[0]].distance(bboxes[connected_region_indices[1]]), (1 + gamma) * fs)
|
| 455 |
+
# print(bboxes[connected_region_indices[0]].angle, bboxes[connected_region_indices[1]].angle, 4 * np.pi / 180)
|
| 456 |
+
|
| 457 |
+
if bboxes[connected_region_indices[0]].distance(bboxes[connected_region_indices[1]]) < (1 + gamma) * fs \
|
| 458 |
+
and abs(bboxes[connected_region_indices[0]].angle - bboxes[connected_region_indices[1]].angle) < 0.2 * np.pi:
|
| 459 |
+
return [set(connected_region_indices)]
|
| 460 |
+
else:
|
| 461 |
+
return [set([connected_region_indices[0]]), set([connected_region_indices[1]])]
|
| 462 |
+
|
| 463 |
+
# case 3
|
| 464 |
+
G = nx.Graph()
|
| 465 |
+
for idx in connected_region_indices:
|
| 466 |
+
G.add_node(idx)
|
| 467 |
+
for (u, v) in itertools.combinations(connected_region_indices, 2):
|
| 468 |
+
G.add_edge(u, v, weight=bboxes[u].distance(bboxes[v]))
|
| 469 |
+
# Get distances from neighbouring bboxes
|
| 470 |
+
edges = nx.algorithms.tree.minimum_spanning_edges(G, algorithm='kruskal', data=True)
|
| 471 |
+
edges = sorted(edges, key=lambda a: a[2]['weight'], reverse=True)
|
| 472 |
+
distances_sorted = [a[2]['weight'] for a in edges]
|
| 473 |
+
fontsize = np.mean([bboxes[idx].font_size for idx in connected_region_indices])
|
| 474 |
+
distances_std = np.std(distances_sorted)
|
| 475 |
+
distances_mean = np.mean(distances_sorted)
|
| 476 |
+
std_threshold = max(0.3 * fontsize + 5, 5)
|
| 477 |
+
|
| 478 |
+
b1, b2 = bboxes[edges[0][0]], bboxes[edges[0][1]]
|
| 479 |
+
max_poly_distance = Polygon(b1.pts).distance(Polygon(b2.pts))
|
| 480 |
+
max_centroid_alignment = min(abs(b1.centroid[0] - b2.centroid[0]), abs(b1.centroid[1] - b2.centroid[1]))
|
| 481 |
+
|
| 482 |
+
# print(edges)
|
| 483 |
+
# print(f'std: {distances_std} < thrshold: {std_threshold}, mean: {distances_mean}')
|
| 484 |
+
# print(f'{distances_sorted[0]} <= {distances_mean + distances_std * sigma}' \
|
| 485 |
+
# f' or {distances_sorted[0]} <= {fontsize * (1 + gamma)}' \
|
| 486 |
+
# f' or {distances_sorted[0] - distances_sorted[1]} < {distances_std * sigma}')
|
| 487 |
+
|
| 488 |
+
if (distances_sorted[0] <= distances_mean + distances_std * sigma \
|
| 489 |
+
or distances_sorted[0] <= fontsize * (1 + gamma)) \
|
| 490 |
+
and (distances_std < std_threshold \
|
| 491 |
+
or max_poly_distance == 0 and max_centroid_alignment < 5):
|
| 492 |
+
return [set(connected_region_indices)]
|
| 493 |
+
else:
|
| 494 |
+
# (split_u, split_v, _) = edges[0]
|
| 495 |
+
# print(f'split between "{bboxes[split_u].pts}", "{bboxes[split_v].pts}"')
|
| 496 |
+
G = nx.Graph()
|
| 497 |
+
for idx in connected_region_indices:
|
| 498 |
+
G.add_node(idx)
|
| 499 |
+
# Split out the most deviating bbox
|
| 500 |
+
for edge in edges[1:]:
|
| 501 |
+
G.add_edge(edge[0], edge[1])
|
| 502 |
+
ans = []
|
| 503 |
+
for node_set in nx.algorithms.components.connected_components(G):
|
| 504 |
+
ans.extend(split_text_region(bboxes, node_set, width, height))
|
| 505 |
+
return ans
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def merge_bboxes_text_region(bboxes: List[Quadrilateral], width, height):
|
| 510 |
+
|
| 511 |
+
# step 1: divide into multiple text region candidates
|
| 512 |
+
G = nx.Graph()
|
| 513 |
+
for i, box in enumerate(bboxes):
|
| 514 |
+
G.add_node(i, box=box)
|
| 515 |
+
|
| 516 |
+
for ((u, ubox), (v, vbox)) in itertools.combinations(enumerate(bboxes), 2):
|
| 517 |
+
# if quadrilateral_can_merge_region_coarse(ubox, vbox):
|
| 518 |
+
if quadrilateral_can_merge_region(ubox, vbox, aspect_ratio_tol=1.3, font_size_ratio_tol=2,
|
| 519 |
+
char_gap_tolerance=1, char_gap_tolerance2=3):
|
| 520 |
+
G.add_edge(u, v)
|
| 521 |
+
|
| 522 |
+
# step 2: postprocess - further split each region
|
| 523 |
+
region_indices: List[Set[int]] = []
|
| 524 |
+
for node_set in nx.algorithms.components.connected_components(G):
|
| 525 |
+
region_indices.extend(split_text_region(bboxes, node_set, width, height))
|
| 526 |
+
|
| 527 |
+
# step 3: return regions
|
| 528 |
+
for node_set in region_indices:
|
| 529 |
+
# for node_set in nx.algorithms.components.connected_components(G):
|
| 530 |
+
nodes = list(node_set)
|
| 531 |
+
txtlns: List[Quadrilateral] = np.array(bboxes)[nodes]
|
| 532 |
+
|
| 533 |
+
# calculate average fg and bg color
|
| 534 |
+
fg_r = round(np.mean([box.fg_r for box in txtlns]))
|
| 535 |
+
fg_g = round(np.mean([box.fg_g for box in txtlns]))
|
| 536 |
+
fg_b = round(np.mean([box.fg_b for box in txtlns]))
|
| 537 |
+
bg_r = round(np.mean([box.bg_r for box in txtlns]))
|
| 538 |
+
bg_g = round(np.mean([box.bg_g for box in txtlns]))
|
| 539 |
+
bg_b = round(np.mean([box.bg_b for box in txtlns]))
|
| 540 |
+
|
| 541 |
+
# majority vote for direction
|
| 542 |
+
dirs = [box.direction for box in txtlns]
|
| 543 |
+
majority_dir_top_2 = Counter(dirs).most_common(2)
|
| 544 |
+
if len(majority_dir_top_2) == 1 :
|
| 545 |
+
majority_dir = majority_dir_top_2[0][0]
|
| 546 |
+
elif majority_dir_top_2[0][1] == majority_dir_top_2[1][1] : # if top 2 have the same counts
|
| 547 |
+
max_aspect_ratio = -100
|
| 548 |
+
for box in txtlns :
|
| 549 |
+
if box.aspect_ratio > max_aspect_ratio :
|
| 550 |
+
max_aspect_ratio = box.aspect_ratio
|
| 551 |
+
majority_dir = box.direction
|
| 552 |
+
if 1.0 / box.aspect_ratio > max_aspect_ratio :
|
| 553 |
+
max_aspect_ratio = 1.0 / box.aspect_ratio
|
| 554 |
+
majority_dir = box.direction
|
| 555 |
+
else :
|
| 556 |
+
majority_dir = majority_dir_top_2[0][0]
|
| 557 |
+
|
| 558 |
+
# sort textlines
|
| 559 |
+
if majority_dir == 'h':
|
| 560 |
+
nodes = sorted(nodes, key=lambda x: bboxes[x].centroid[1])
|
| 561 |
+
elif majority_dir == 'v':
|
| 562 |
+
nodes = sorted(nodes, key=lambda x: -bboxes[x].centroid[0])
|
| 563 |
+
txtlns = np.array(bboxes)[nodes]
|
| 564 |
+
|
| 565 |
+
# yield overall bbox and sorted indices
|
| 566 |
+
yield txtlns, (fg_r, fg_g, fg_b), (bg_r, bg_g, bg_b)
|
| 567 |
+
|
| 568 |
+
|
utils/watermark_utils.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
def apply_watermark_to_pil_image(img_pil: Image.Image, watermark_path: str, opacity: float = 0.7) -> Image.Image:
|
| 5 |
+
"""
|
| 6 |
+
Apply watermark to a PIL image
|
| 7 |
+
|
| 8 |
+
Args:
|
| 9 |
+
img_pil (Image.Image): Source PIL image
|
| 10 |
+
watermark_path (str): Path to watermark image
|
| 11 |
+
opacity (float): Watermark opacity (0.0 - 1.0)
|
| 12 |
+
|
| 13 |
+
Returns:
|
| 14 |
+
Image.Image: Watermarked PIL image
|
| 15 |
+
"""
|
| 16 |
+
if not osp.exists(watermark_path):
|
| 17 |
+
return img_pil
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
watermark = Image.open(watermark_path)
|
| 21 |
+
except Exception:
|
| 22 |
+
return img_pil
|
| 23 |
+
|
| 24 |
+
# Ensure images are in RGBA mode
|
| 25 |
+
if img_pil.mode != 'RGBA':
|
| 26 |
+
img_pil = img_pil.convert('RGBA')
|
| 27 |
+
if watermark.mode != 'RGBA':
|
| 28 |
+
watermark = watermark.convert('RGBA')
|
| 29 |
+
|
| 30 |
+
# Fixed watermark size (adjust as needed)
|
| 31 |
+
WATERMARK_FIXED_WIDTH = 418
|
| 32 |
+
WATERMARK_FIXED_HEIGHT = 120
|
| 33 |
+
|
| 34 |
+
# Resize watermark
|
| 35 |
+
watermark = watermark.resize((WATERMARK_FIXED_WIDTH, WATERMARK_FIXED_HEIGHT), Image.LANCZOS)
|
| 36 |
+
|
| 37 |
+
# Apply opacity
|
| 38 |
+
if opacity < 1.0:
|
| 39 |
+
alpha = watermark.split()[3]
|
| 40 |
+
alpha = alpha.point(lambda p: p * opacity)
|
| 41 |
+
watermark.putalpha(alpha)
|
| 42 |
+
|
| 43 |
+
# Get image dimensions
|
| 44 |
+
img_width, img_height = img_pil.size
|
| 45 |
+
|
| 46 |
+
# Create transparent layer for watermarks
|
| 47 |
+
wm_layer = Image.new('RGBA', img_pil.size, (0, 0, 0, 0))
|
| 48 |
+
|
| 49 |
+
# Calculate watermark positions (bottom to top)
|
| 50 |
+
initial_y = img_height - watermark.height - 10 # 10px from bottom
|
| 51 |
+
x_position = 10 # 10px from left
|
| 52 |
+
|
| 53 |
+
# Repeat watermark vertically
|
| 54 |
+
current_y = initial_y
|
| 55 |
+
while current_y > -watermark.height:
|
| 56 |
+
if current_y < 0:
|
| 57 |
+
# Crop watermark if it goes beyond top boundary
|
| 58 |
+
crop_height = watermark.height + current_y
|
| 59 |
+
if crop_height > 0:
|
| 60 |
+
partial_wm = watermark.crop((0, -current_y, watermark.width, watermark.height))
|
| 61 |
+
wm_layer.paste(partial_wm, (x_position, 0), partial_wm)
|
| 62 |
+
else:
|
| 63 |
+
wm_layer.paste(watermark, (x_position, current_y), watermark)
|
| 64 |
+
|
| 65 |
+
current_y -= 8000 # Vertical spacing (adjust as needed)
|
| 66 |
+
|
| 67 |
+
# Composite original image with watermark layer
|
| 68 |
+
return Image.alpha_composite(img_pil, wm_layer)
|
utils/zluda_config.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# 检测是否包含 ZLUDA 标记
|
| 5 |
+
def zluda_available(device_name):
|
| 6 |
+
return "[ZLUDA]" in device_name
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# 关闭 ZLUDA Cudnn 支持 防止错误
|
| 10 |
+
def enable_zluda_config():
|
| 11 |
+
if hasattr(torch, 'cuda') and torch.cuda.is_available():
|
| 12 |
+
device_name = torch.cuda.get_device_name(0)
|
| 13 |
+
print('Device name: ', device_name)
|
| 14 |
+
print('Cuda is available: ', torch.cuda.is_available())
|
| 15 |
+
print('Cuda version: ', torch.version.cuda)
|
| 16 |
+
print('ZLUDA is available: ', zluda_available(device_name))
|
| 17 |
+
|
| 18 |
+
if zluda_available(device_name):
|
| 19 |
+
torch.backends.cudnn.enabled = False
|
| 20 |
+
cuda_attr = torch.backends.cuda
|
| 21 |
+
if hasattr(cuda_attr, 'enable_flash_sdp'):
|
| 22 |
+
torch.backends.cuda.enable_flash_sdp(False)
|
| 23 |
+
print('Cuda enable flash sdp: ', False)
|
| 24 |
+
if hasattr(cuda_attr, 'enable_math_sdp'):
|
| 25 |
+
torch.backends.cuda.enable_math_sdp(True)
|
| 26 |
+
print('Cuda enable math sdp: ', True)
|
| 27 |
+
if hasattr(cuda_attr, 'enable_mem_efficient_sdp'):
|
| 28 |
+
torch.backends.cuda.enable_mem_efficient_sdp(False)
|
| 29 |
+
print('Cuda enable mem efficient sdp: ', False)
|
| 30 |
+
if hasattr(cuda_attr, 'enable_cudnn_sdp'):
|
| 31 |
+
torch.backends.cuda.enable_cudnn_sdp(False)
|
| 32 |
+
print('Cuda enable cudnn sdp: ', False)
|