| from pathlib import Path; import sys, os; from fnmatch import fnmatch |
| import global_ |
|
|
# Task ids handled by this run (must stay a tuple; checked further down).
TASKS = (0, 1, 2, 3)
# Tensor-parallel switch. When truthy, gate_() does not restrict output to rank 0
# and str_t_pid() suffixes with the rank instead of the pid.
TP_enable: bool = 1

# Process coordinates from the launcher's environment (presumably torchrun-style
# WORLD_SIZE / RANK / LOCAL_RANK); the defaults describe a single-process run.
world_size_ = int(os.environ.get("WORLD_SIZE", "1"))
rank_ = int(os.environ.get("RANK", "0"))
local_rank_ = int(os.environ.get("LOCAL_RANK", rank_))

# Explicit check instead of `assert`, which is stripped under `python -O` and
# would let an inconsistent environment through silently.
if world_size_ < 1 or not 0 <= rank_ < world_size_:
    raise ValueError(
        f"invalid distributed environment: WORLD_SIZE={world_size_}, RANK={rank_}"
    )
|
|
# Module-level feature switches.
USE_filter_mediapipe_fail_swap = 1
CH14: bool = False


class REFNET:
    """Namespace of settings grouped under REFNET (accessed as class attributes).

    NOTE(review): the class extent here is taken to be the contiguous settings
    ENABLE .. N_EPOCHS_TRAIN_REF_AND_MID, with ZERO1_ENABLE left at module level
    -- confirm against the callers.
    """

    ENABLE: bool = 1
    CH9: bool = 0
    # Layer count per task id (currently 9 for every task).
    task2layerNum = {0: 9, 1: 9, 2: 9, 3: 9}
    USE_pts: bool = 1
    READ_mediapipe_result_from_cache = 1
    ADAM_or_SGD: bool = False
    N_EPOCHS_TRAIN_REF_AND_MID: int = 1


# Module-level switch; presumably toggles ZeRO stage-1 optimizer sharding -- confirm.
ZERO1_ENABLE: bool = 0
|
|
|
|
NUM_token = 257

# Checkpoint locations, relative to the current working directory.
# SD14_localpath is a pathlib.Path; the PRETRAIN_* paths are plain strings.
SD14_filename = "sd-v1-4.ckpt"
SD14_localpath = Path("checkpoints") / SD14_filename
PRETRAIN_CKPT_PATH = "checkpoints/pretrained.ckpt"
PRETRAIN_JSON_PATH = "checkpoints/pretrained.json"
|
|
| |
# Internal sanity check on the constant defined at the top of this module.
assert isinstance(TASKS, tuple)

NUM_pts = 95

# Publish the parallelism settings on the shared `global_` module, which is
# where gate_() and str_t_pid() read them from.
global_.TP_enable = TP_enable
global_.rank_ = rank_
|
|
|
|
|
|
MERGE_CFG_in_one_batch: bool = 1  # feature toggle (1 = on)
|
|
| import dlib |
FOR_upcycle_ckpt_GEN_or_USE: bool = 0  # feature toggle (0 = off)
|
|
# Debug switches (0 = off).
DEBUG = 0
DEBUG_skip_load_ckpt: bool = 0
# NOTE: "DBEUG" is a misspelling of "DEBUG"; the name is kept unchanged because
# other modules may reference it under this spelling.
DBEUG_skip_most_in_Unet_constructor: bool = 0

# Non-zero: the root logger and its coloured handler run at DEBUG instead of
# INFO, and custom warning output also shows filename:lineno.
LOG_debug_level = 0
|
|
|
|
# Per-id counters maintained by gate_(); entries are created lazily on first use.
_gate_total_runs = {}   # id -> number of times the gate has opened
_gate_total_calls = {}  # id -> number of eligible calls counted so far

# Gate table consumed by gate_():  id -> (max_run, interval, prob)
#   max_run  : how many times the gate may open in total (0 disables the id)
#   interval : only every `interval`-th counted call may open the gate
#   prob     : None -> always open on such a call; otherwise open only when
#              random.random() <= prob
_gate_k2tu = {
    # --- visualisation / one-off dumps (disabled) ---
    'vis Dataset_vFrame perspectiveWarp': (0, 1, None),
    'vis LatentDiffusion.get_input': (0, 5, None),
    'vis LatentDiffusion.get_input-before_return True': (0, 5, None),
    'vis LatentDiffusion.get_input-before_return False': (0, 1, None),
    'vis LatentDiffusion.conditioning_with_feat': (0, 2, None),
    'vis LatentDiffusion.p_losses--after-apply_model': (0, 2, None),
    'statistics test_batch[0]': (0, 2, None),
    "Project config:": (0, 1, None),
    "Lightning config:": (0, 1, None),
    "logger_cfg=": (0, 1, None),
    "Merged modelckpt-cfg:": (0, 1, None),
    'bank get': (0, 1, None),
    'bank set': (0, 1, None),
    'clear': (0, 1, None),
    'mean ct:': (0, 1, None),
    # --- data pipeline / scheduling ---
    "[__iter__]": (1, 1, None),
    "[_create_batches]": (1, 1, None),
    '[set_task_for_MoE]': (1, 1, None),
    'len_inter': (3, 5, None),
    'non_paired': (3, 5, None),
    'ddim rec bg': (4, 5, None),
    '[training step]': (7, 1, None),
    'LatentDiffusion.configure_optimizers params:': (0, 1, None),
    # --- conditioning shapes ---
    'c.shape': (2, 6, None),
    '[conditioning_with_feat return]': (0, 6, None),
    'c for refNet': (0, 6, None),
    'hair _c.shape:': (0, 1, None),
    'head _c.shape:': (0, 1, None),
    'task': (9, 1, None),
    '_t_norm': (9, 1, None),
    # --- losses ---
    'orig,ID clip,lpips rec lmk:': (20, 2, None),
    'loss_lpips_1 at 0 0 :': (10, 4, None),
    'loss_lpips_1 at 0 1 :': (10, 4, None),
    'loss_lpips_1 at 0 2 :': (10, 4, None),
    'loss_lpips_1 at 1 0 :': (10, 4, None),
    'loss_lpips_1 at 1 1 :': (10, 4, None),
    'loss_lpips_1 at 1 2 :': (10, 4, None),
    'loss_rec_1 at 0 :': (10, 4, None),
    'loss_rec_1 at 1 :': (10, 4, None),
    'orig, ID clip, lpips rec lmk:': (10, 4, None),
    # --- reference conditioning / MoE gate inputs ---
    'c_ref True': (3, 5, None),
    'c_ref False': (1, 1, None),
    'ffn_gate_input': (3, 3, None),
    'vis-ffn_gate_input': (3, 3, None),
    # --- tensor-parallel sync / parameter stats ---
    '[warning]: no param to sync': (10, 1, None),
    '[TP] shared sync counts': (10, 1, None),
    '[Conv2d param stats] count, name (sorted desc):': (0, 1, None),
    'avg full_name=': (0, 1, None),
}
def _gate_nonzero_rank():
    """Return True if this process is a non-zero rank of an initialised torch.distributed group.

    Returns False when torch is not installed, torch.distributed is unavailable,
    or no process group has been initialised (i.e. a single-process run).
    """
    try:
        import torch.distributed as dist
    except ImportError:
        return False
    return bool(dist.is_available() and dist.is_initialized() and dist.get_rank() != 0)


def gate_(id_, *args, **kw):
    """Decide whether the rate-limited action identified by ``id_`` should fire now.

    ``id_`` is looked up in ``_gate_k2tu`` as ``(max_run, interval, prob)``. The
    gate opens on every ``interval``-th counted call until it has opened
    ``max_run`` times; when ``prob`` is not None, an eligible call opens only if
    ``random.random() <= prob``. Per-id counters live in ``_gate_total_runs`` and
    ``_gate_total_calls``. Extra positional/keyword arguments are ignored.

    Unless tensor parallelism is enabled (``global_.TP_enable`` truthy), a
    non-zero rank of an initialised torch.distributed group never opens a gate
    and does not advance the counters.

    Returns:
        bool: True if the caller should act, otherwise False (unknown, unhashable
        or disabled id, exhausted budget, off-interval call, failed draw, or a
        non-zero rank).
    """
    try:
        entry = _gate_k2tu.get(id_)
    except TypeError:
        # Unhashable ids can never be table keys; treat them as unknown.
        return False
    if entry is None:
        return False
    max_run, interval, prob = entry
    if max_run == 0:
        return False

    # Outside tensor-parallel mode only rank 0 may open a gate.
    if not getattr(global_, 'TP_enable', False) and _gate_nonzero_rank():
        return False

    _gate_total_runs.setdefault(id_, 0)
    _gate_total_calls.setdefault(id_, 0)
    if _gate_total_runs[id_] >= max_run:
        return False
    _gate_total_calls[id_] += 1
    # Clamp so a mis-configured interval <= 0 means "every call" instead of a
    # ZeroDivisionError (or sign-dependent modulo) inside a logging helper.
    if _gate_total_calls[id_] % max(1, interval) != 0:
        return False
    if prob is not None:
        import random
        if random.random() > prob:
            return False
    _gate_total_runs[id_] += 1
    return True
|
|
def str_t():
    """Return the current local time as ``MMDD-HH.MM.SS`` (e.g. ``0315-09.05.07``)."""
    from datetime import datetime
    return datetime.now().strftime("%m%d-%H.%M.%S")
def str_t_pid():
    """Return ``str_t()`` followed by ``-<id>``.

    ``<id>`` is ``global_.rank_`` when ``global_.TP_enable`` is truthy, otherwise
    the current process id.
    """
    use_rank = getattr(global_, 'TP_enable', False)
    suffix = global_.rank_ if use_rank else os.getpid()
    return "{}-{}".format(str_t(), suffix)
|
|
def printC(*args, **kw):
    """Gated ``print``: the first positional argument doubles as the gate id.

    Prints ``*args, **kw`` only when ``gate_(args[0])`` is truthy, and always
    returns None. At least one positional argument is required.
    """
    should_print = gate_(args[0])
    if should_print:
        print(*args, **kw)
|
|
| |
|
|
# Export TF_ENABLE_ONEDNN_OPTS=0 into this process's environment (inherited by
# child processes). Presumably meant to turn off TensorFlow's oneDNN custom ops.
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
| |
|
|
|
|
def path_img_2_path_mask(path_img, check_mask_exists=1, reuse_if_exists=True, label_mode="RF12_"):
    """Return the semantic-mask path paired with an image, generating the mask if required.

    The mask sits next to the image as ``<stem>-semantic_mask.png``.

    Args:
        path_img: Image path (``str`` or ``Path``). It must not itself contain
            ``semantic_mask``.
        check_mask_exists: When truthy, generate the mask if the file is missing.
        reuse_if_exists: When falsy, always (re)generate the mask, even if it
            already exists and regardless of ``check_mask_exists``.
        label_mode: Label scheme; only ``"RF12_"`` is supported.

    Returns:
        Path: The mask path. Existence is not checked when ``check_mask_exists``
        is falsy and ``reuse_if_exists`` is truthy.

    Raises:
        AssertionError: For an unsupported ``label_mode`` (the message shows the
        rejected value) or a mask-like ``path_img``.
    """
    # "RF12_" is the only mode with a suffix mapping, so it is the only accepted one.
    assert label_mode == "RF12_", label_mode
    assert 'semantic_mask' not in str(path_img), path_img
    path_img = Path(path_img)
    suffix = {"RF12_": '-semantic_mask'}[label_mode]
    path_mask = path_img.parent / f"{path_img.stem}{suffix}.png"

    # Generate when forced (reuse disabled), or when asked to check and the file is missing.
    must_generate = (not reuse_if_exists) or (check_mask_exists and not path_mask.exists())
    if must_generate:
        # Project-local generator, imported only when a mask is actually produced.
        from gen_semantic_mask import gen_semantic_mask
        vis_path = None
        gen_semantic_mask(path_img, path_mask, label_mode, vis_path)
    return path_mask
|
|
| from my_py_lib.torchModuleName_util import * |
# Disabled toggle: change to `if 1:` to install IPython's colourised traceback
# hook (requires IPython to be installed).
if 0:
    import sys
    from IPython.core.ultratb import ColorTB
    sys.excepthook = ColorTB()
class _color:
    """ANSI escape sequences used to colour the log and warning output below."""

    grey = "\x1b[90m"
    green = "\x1b[92m"
    yellow = "\x1b[93m"
    red = "\x1b[91m"
    orange = "\x1b[38;5;208m"        # 256-colour palette entry 208
    orange_light = "\x1b[38;5;214m"  # 256-colour palette entry 214
if 1:
    import logging

    class _CustomFormatter(logging.Formatter):
        """Format records as ``HH:MM:SS | LEVEL | message`` in a per-level ANSI colour.

        Levels without an entry in ``FORMATS`` (custom levels) use the same
        layout, uncoloured.
        """

        # Named _FMT rather than ``format``: a class attribute called ``format``
        # is immediately shadowed by the ``format()`` method below.
        _FMT = "%(asctime)s | %(levelname)-5s | %(message)s"
        _DATEFMT = '%H:%M:%S'
        reset = "\x1b[0m"
        FORMATS = {
            logging.DEBUG: _color.grey + _FMT + reset,
            logging.INFO: _color.green + _FMT + reset,
            logging.WARNING: _color.yellow + _FMT + reset,
            logging.ERROR: _color.red + _FMT + reset,
            logging.CRITICAL: _color.red + _FMT + reset,
        }

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Build each per-level Formatter once instead of once per record.
            self._by_level = {
                level: logging.Formatter(fmt, datefmt=self._DATEFMT)
                for level, fmt in self.FORMATS.items()
            }
            self._fallback = logging.Formatter(self._FMT, datefmt=self._DATEFMT)

        def format(self, record):
            return self._by_level.get(record.levelno, self._fallback).format(record)

    def setup_colored_logging():
        """Attach the coloured stream handler to the root logger (idempotent).

        The root logger and the handler are set to DEBUG when ``LOG_debug_level``
        is truthy, else INFO. A repeated call only refreshes the levels instead
        of adding a second handler, which would print every record twice.
        """
        logger = logging.getLogger()
        level = logging.DEBUG if LOG_debug_level else logging.INFO
        logger.setLevel(level)
        handler = next(
            (h for h in logger.handlers if getattr(h, '_is_colored_handler', False)),
            None,
        )
        if handler is None:
            handler = logging.StreamHandler()
            handler._is_colored_handler = True
            handler.setFormatter(_CustomFormatter())
            logger.addHandler(handler)
        handler.setLevel(level)

    setup_colored_logging()
if 1:
    import warnings

    def _custom_showwarning(msg, category, filename, lineno, file=None, line=None):
        """Coloured replacement for ``warnings.showwarning``.

        Emits ``Category: message``, followed by `` filename:lineno`` when
        ``LOG_debug_level`` is truthy. The text goes to ``file`` when the caller
        supplies one, otherwise to stdout (as before). ``line`` is ignored.
        """
        reset = "\x1b[0m"
        text = f"{_color.orange}{category.__name__}{reset}: {_color.yellow}{msg}{reset}"
        if LOG_debug_level:
            text += f" {_color.grey}{filename}:{lineno}{reset}"
        # print(file=None) writes to sys.stdout, so the default destination is unchanged.
        print(text, file=file)

    warnings.showwarning = _custom_showwarning
|
|
if __name__ == '__main__':
    # Manual smoke test of the coloured logging and warning hooks set up above.
    logging.warning("This is a warning message in yellow")
    logging.error("This is an error message in red")
    warnings.warn("This is a colored warning message")
|
|