"""Project-wide run-time switches plus small debug / logging helpers.

Contents: training/config flags, distributed rank info (mirrored into the
shared ``global_`` module), the ``gate_`` / ``printC`` throttles for debug
visualisation and prints, timestamp helpers, the image -> semantic-mask path
helper, and colored logging / warning output.

NOTE(review): this file was recovered from a whitespace-collapsed paste; a few
nesting decisions had to be inferred and are marked NOTE(review) inline --
confirm them against version control.
"""
from pathlib import Path
import sys
import os
from fnmatch import fnmatch

import global_

TASKS = (0, 1, 2, 3,)
TP_enable: bool = 1  # "TP" switch (presumably tensor parallelism); mirrored into `global_` below

# Launch info from the environment; the defaults describe a single process.
world_size_ = int(os.environ.get("WORLD_SIZE", "1"))
rank_ = int(os.environ.get("RANK", "0"))
local_rank_ = int(os.environ.get("LOCAL_RANK", rank_))  # falls back to the global rank
assert world_size_ >= 1 and 0 <= rank_ < world_size_

USE_filter_mediapipe_fail_swap = 1
CH14: bool = False


class REFNET:
    ENABLE: bool = 1
    CH9: bool = 0
    task2layerNum = {  # actually used as bool now
        0: 9,
        1: 9,
        2: 9,
        3: 9,
    }


# NOTE(review): placed at module level when un-collapsing; confirm it was not a REFNET attribute.
USE_pts: bool = 1
READ_mediapipe_result_from_cache = 1
ADAM_or_SGD: bool = False  # 1 => AdamW ; 0 => sgd
N_EPOCHS_TRAIN_REF_AND_MID: int = 1
# ZeRO-1 optimizer sharding (ZeroRedundancyOptimizer). Avoid FSDP; use ZeroRedundancyOptimizer only.
ZERO1_ENABLE: bool = 0
NUM_token = 257

# Checkpoint locations (SD14 = Stable Diffusion v1.4, judging by the filename).
SD14_filename = "sd-v1-4.ckpt"
SD14_localpath = Path("checkpoints") / SD14_filename
PRETRAIN_CKPT_PATH = "checkpoints/pretrained.ckpt"
PRETRAIN_JSON_PATH = "checkpoints/pretrained.json"
#-------------------------------------------
assert isinstance(TASKS, tuple)
NUM_pts = 95

# Publish the TP switch and rank on the shared `global_` module; gate_() and
# str_t_pid() below read them back from there.
global_.TP_enable = TP_enable
global_.rank_ = rank_

MERGE_CFG_in_one_batch: bool = 1
import dlib  # NOTE(review): not used in this file; presumably imported here on purpose -- keep its position.
FOR_upcycle_ckpt_GEN_or_USE: bool = 0
DEBUG = 0
DEBUG_skip_load_ckpt: bool = 0
DBEUG_skip_most_in_Unet_constructor: bool = 0  # (sic) misspelled public name; renaming could break references
# import os; os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
LOG_debug_level = 0  # non-zero => DEBUG-level root logging and file:line in warnings

# Per-id counters used by gate_(): passes so far / calls counted so far.
_gate_total_runs = {}
_gate_total_calls = {}
# id -> (max_run, interval, prob); see gate_() for the semantics.
_gate_k2tu = {
    'vis Dataset_vFrame perspectiveWarp': (0, 1, None),
    'vis LatentDiffusion.get_input': (0, 5, None),
    'vis LatentDiffusion.get_input-before_return True': (0, 5, None),
    'vis LatentDiffusion.get_input-before_return False': (0, 1, None),
    'vis LatentDiffusion.conditioning_with_feat': (0, 2, None),
    'vis LatentDiffusion.p_losses--after-apply_model': (0, 2, None),
    'statistics test_batch[0]': (0, 2, None),  #-------------infer-----------
    "Project config:": (0, 1, None),  # ------------for printC (arg[0] as id)-----------
    "Lightning config:": (0, 1, None),
    "logger_cfg=": (0, 1, None),
    "Merged modelckpt-cfg:": (0, 1, None),
    'bank get': (0, 1, None),
    'bank set': (0, 1, None),
    'clear': (0, 1, None),
    'mean ct:': (0, 1, None),
    "[__iter__]": (1, 1, None),
    "[_create_batches]": (1, 1, None),
    '[set_task_for_MoE]': (1, 1, None),
    'len_inter': (3, 5, None),
    'non_paired': (3, 5, None),
    'ddim rec bg': (4, 5, None),
    '[training step]': (7, 1, None),
    'LatentDiffusion.configure_optimizers params:': (0, 1, None),
    'c.shape': (2, 6, None),
    '[conditioning_with_feat return]': (0, 6, None),
    'c for refNet': (0, 6, None),
    'hair _c.shape:': (0, 1, None),
    'head _c.shape:': (0, 1, None),
    'task': (9, 1, None),
    '_t_norm': (9, 1, None),
    'orig,ID clip,lpips rec lmk:': (20, 2, None),  #-------------ddim_losses-----------
    'loss_lpips_1 at 0 0 :': (10, 4, None),
    'loss_lpips_1 at 0 1 :': (10, 4, None),
    'loss_lpips_1 at 0 2 :': (10, 4, None),
    'loss_lpips_1 at 1 0 :': (10, 4, None),
    'loss_lpips_1 at 1 1 :': (10, 4, None),
    'loss_lpips_1 at 1 2 :': (10, 4, None),
    'loss_rec_1 at 0 :': (10, 4, None),
    'loss_rec_1 at 1 :': (10, 4, None),
    'orig, ID clip, lpips rec lmk:': (10, 4, None),
    'c_ref True': (3, 5, None),
    'c_ref False': (1, 1, None),
    'ffn_gate_input': (3, 3, None),  #-------------MoE-----------
    'vis-ffn_gate_input': (3, 3, None),
    '[warning]: no param to sync': (10, 1, None),  #-------------TP-----------
    '[TP] shared sync counts': (10, 1, None),
    '[Conv2d param stats] count, name (sorted desc):': (0, 1, None),  #-------------upcycle-----------
    'avg full_name=': (0, 1, None),
}


def gate_(id_, *args, **kw, ):
    """Decide whether the debug/vis action identified by ``id_`` should fire now.

    The rule comes from ``_gate_k2tu[id_] = (max_run, interval, prob)``:

    * unknown ids, and ids with ``max_run == 0``, never fire;
    * an id fires at most ``max_run`` times per process;
    * only every ``interval``-th counted call may fire;
    * if ``prob`` is not None, a call that survives the interval check fires
      with probability ``prob``;
    * when TP is disabled and torch.distributed is initialized, ranks other
      than 0 never fire.

    ``*args`` / ``**kw`` are accepted and ignored.

    Returns:
        bool: True if the caller should perform the gated action.
    """
    # return False  # uncomment to silence every gated action
    if not (hasattr(global_, 'TP_enable') and global_.TP_enable):
        import torch.distributed as dist
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return False
    cfg = _gate_k2tu.get(id_)
    if cfg is None:
        return False
    max_run, interval, prob = cfg
    if max_run == 0:
        return False
    if id_ not in _gate_total_runs:  # first call for this id
        _gate_total_runs[id_] = 0
        _gate_total_calls[id_] = 0
    if _gate_total_runs[id_] >= max_run:  # budget exhausted; stop counting calls
        return False
    _gate_total_calls[id_] += 1
    if _gate_total_calls[id_] % interval != 0:
        return False
    if prob is not None:
        import random
        if random.random() > prob:  # this slot is consumed even though nothing fires
            return False
    _gate_total_runs[id_] += 1
    return True


def str_t():
    """Return the current local time as ``'MMDD-HH.MM.SS'``, e.g. ``'0608-17.12.30'``."""
    from datetime import datetime
    return datetime.now().strftime("%m%d-%H.%M.%S")


def str_t_pid():
    """Return ``str_t()`` plus a process suffix, e.g. ``'0608-17.12.30-180165'``.

    The suffix is ``global_.rank_`` when TP is enabled, otherwise the OS pid.
    """
    if hasattr(global_, 'TP_enable') and global_.TP_enable:
        _suffix = global_.rank_
    else:
        _suffix = os.getpid()
    return f"{str_t()}-{_suffix}"


def printC(*args, **kw):
    """Controlled print: forward to ``print`` only when ``gate_(args[0])`` allows it.

    The first positional argument doubles as the gate id (a key of ``_gate_k2tu``).
    A call with no positional arguments prints nothing.
    """
    if args and gate_(args[0]):
        print(*args, **kw)


#--------------------
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # disable tf onednn-related warnings


def path_img_2_path_mask(path_img, check_mask_exists=1, reuse_if_exists=True, label_mode="RF12_"):
    """Return the semantic-mask path paired with ``path_img``, generating the mask if needed.

    The mask path is ``<image dir>/<image stem>-semantic_mask.png``.

    Args:
        path_img: image path (str or Path); must not itself be a mask path.
        check_mask_exists: if truthy, generate the mask when the file is missing.
        reuse_if_exists: if False, always call ``gen_semantic_mask``, even when
            the mask file already exists.
        label_mode: only ``"RF12_"`` is supported.

    Returns:
        Path: the mask path. Existence is not re-checked after generation.
    """
    assert label_mode == "RF12_", label_mode
    assert 'semantic_mask' not in str(path_img), path_img
    path_img = Path(path_img)
    _suffix = {
        # "RF12" :'-semantic_mask',
        "RF12_": '-semantic_mask',
    }[label_mode]
    path_mask = path_img.parent / f"{path_img.stem}{_suffix}.png"
    # Generate when regeneration is forced (reuse_if_exists=False), or when an
    # existence check was requested and the mask file is missing.
    if (check_mask_exists or not reuse_if_exists) and (not path_mask.exists() or not reuse_if_exists):
        from gen_semantic_mask import gen_semantic_mask  # project-local, imported lazily
        vis_path = None
        # vis_path = person_folder.parent.parent / 'vis_semantic_mask' / f"{person_stem}--{path_img.stem}.png"
        gen_semantic_mask(path_img, path_mask, label_mode, vis_path, )
    return path_mask


from my_py_lib.torchModuleName_util import *

#-------------------- terminal color (only for exceptions/logging/warnings)
# NOTE(review): recovered from a collapsed paste. The `if 0:` is taken to guard only
# the IPython traceback hook (one toggle per feature: exceptions off, logging on,
# warnings on). Confirm it did not wrap the whole section.
if 0:
    from IPython.core.ultratb import ColorTB
    sys.excepthook = ColorTB()


class _color:  # ANSI escape sequences
    grey = "\x1b[90m"
    green = "\x1b[92m"
    yellow = "\x1b[93m"
    red = "\x1b[91m"
    orange = "\033[38;5;208m"
    orange_light = "\033[38;5;214m"


import logging


class _CustomFormatter(logging.Formatter):
    """Color each whole log line by level; timestamps show the time only (no date)."""

    # _fmt_str = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s [%(levelname)-8s] %(message)s"
    _fmt_str = "%(asctime)s | %(levelname)-5s | %(message)s"
    reset = "\x1b[0m"
    FORMATS = {
        logging.DEBUG: _color.grey + _fmt_str + reset,
        logging.INFO: _color.green + _fmt_str + reset,
        logging.WARNING: _color.yellow + _fmt_str + reset,
        logging.ERROR: _color.red + _fmt_str + reset,
        logging.CRITICAL: _color.red + _fmt_str + reset,
    }
    _DATEFMT = '%H:%M:%S'  # <= print only time, not date

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Build one formatter per level once, instead of one per record.
        self._by_level = {
            level: logging.Formatter(fmt, datefmt=self._DATEFMT)
            for level, fmt in self.FORMATS.items()
        }
        # Custom (non-standard) levels keep the same layout, uncolored.
        self._fallback = logging.Formatter(self._fmt_str, datefmt=self._DATEFMT)

    def format(self, record):
        return self._by_level.get(record.levelno, self._fallback).format(record)


def setup_colored_logging():
    """Attach a colored StreamHandler (stderr) to the root logger; safe to call repeatedly.

    The root logger and handler level is DEBUG if ``LOG_debug_level`` is set, else INFO.
    """
    level = logging.DEBUG if LOG_debug_level else logging.INFO
    logger = logging.getLogger()
    logger.setLevel(level)
    for handler in logger.handlers:
        if getattr(handler, '_is_colored_handler', False):
            handler.setLevel(level)
            return  # already installed; a second handler would duplicate every line
    ch = logging.StreamHandler()
    ch.setLevel(level)
    ch.setFormatter(_CustomFormatter())
    ch._is_colored_handler = True
    logger.addHandler(ch)


setup_colored_logging()

import warnings


def _custom_showwarning(msg, category, filename, lineno, file=None, line=None):
    """Drop-in for ``warnings.showwarning`` that prints a colored one-line warning.

    Writes to ``file`` when given, otherwise to stdout. ``filename:lineno`` is
    appended only when ``LOG_debug_level`` is set. ``line`` is ignored.
    """
    reset = "\x1b[0m"
    c_file_line, c_cate, c_msg = _color.grey, _color.orange, _color.yellow
    formatted_message = f"{c_cate}{category.__name__}{reset}: {c_msg}{msg}{reset}"
    if LOG_debug_level:
        formatted_message += f" {c_file_line}{filename}:{lineno}{reset}"
    print(formatted_message, file=file)


warnings.showwarning = _custom_showwarning

if __name__ == '__main__':
    logging.warning("This is a warning message in yellow")
    logging.error("This is an error message in red")
    warnings.warn("This is a colored warning message")