UniBioTransfer / util_and_constant.py
scy639's picture
Upload folder using huggingface_hub
fddd19d verified
from pathlib import Path; import sys, os; from fnmatch import fnmatch
import global_
# ---- task list / parallelism switches ----
TASKS = (0, 1, 2, 3)  # task ids; a later assert in this file requires this to be a tuple
TP_enable: bool = 1  # copied onto the shared `global_` module further down in this file

# ---- distributed launch coordinates, read once from the environment ----
# Missing variables fall back to a single-process layout (world size 1, rank 0);
# LOCAL_RANK falls back to the global rank.
world_size_ = int(os.getenv("WORLD_SIZE", "1"))
rank_ = int(os.getenv("RANK", "0"))
local_rank_ = int(os.getenv("LOCAL_RANK", rank_))
assert world_size_ >= 1
assert 0 <= rank_ < world_size_

USE_filter_mediapipe_fail_swap = 1
CH14: bool = False
class REFNET:
    """Namespace of switches grouped under the REFNET name.

    NOTE(review): presumably these configure the reference network; confirm
    against the code that reads them.
    """

    ENABLE: bool = 1
    CH9: bool = 0
    # task id -> layer count; per the original note the value is now only
    # used as a boolean.
    task2layerNum = {0: 9, 1: 9, 2: 9, 3: 9}
# ---- training / data switches ----
USE_pts: bool = 1
READ_mediapipe_result_from_cache = 1
ADAM_or_SGD: bool = False  # truthy => AdamW ; falsy => SGD
N_EPOCHS_TRAIN_REF_AND_MID: int = 1
# ZeRO-1 optimizer sharding (ZeroRedundancyOptimizer). avoid using FSDP, just ZeroRedundancyOptimizer
ZERO1_ENABLE: bool = 0
NUM_token = 257

# ---- checkpoint locations (relative to the working directory) ----
SD14_filename = "sd-v1-4.ckpt"
SD14_localpath = Path("checkpoints", SD14_filename)
PRETRAIN_CKPT_PATH = "checkpoints/pretrained.ckpt"
PRETRAIN_JSON_PATH = "checkpoints/pretrained.json"
#-------------------------------------------
# Fail at import time if TASKS was edited into something other than a tuple.
assert isinstance(TASKS,tuple)
NUM_pts = 95
# Publish the parallelism flag and the rank on the shared `global_` module;
# gate_() and str_t_pid() in this file read them back from there.
global_.TP_enable = TP_enable
global_.rank_ = rank_
MERGE_CFG_in_one_batch :bool = 1
# NOTE(review): dlib is not referenced in the visible part of this file; presumably
# it is imported here for ordering/side-effect reasons -- confirm before moving it.
import dlib
FOR_upcycle_ckpt_GEN_or_USE :bool = 0
# ---- debug switches ----
DEBUG = 0
DEBUG_skip_load_ckpt :bool = 0
# NOTE: the name is spelled "DBEUG" (sic); left unchanged because code outside
# this file may refer to it by this exact name.
DBEUG_skip_most_in_Unet_constructor :bool = 0
# import os; os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Truthy => the logging setup below uses DEBUG level and warnings also show file:line.
LOG_debug_level = 0
# Per-id bookkeeping used by gate_():
#   _gate_total_runs[id]  -> how many times gate_() has returned True for that id
#   _gate_total_calls[id] -> how many calls were counted for that id (counting stops
#                            once the id has used up its max_run budget)
_gate_total_runs = {}
_gate_total_calls = {}
# Rule table consulted by gate_(); printC() uses its first positional argument as the id.
# Each value is (max_run, interval, prob):
#   max_run  : maximum number of times the gate opens for this id; 0 disables the id
#   interval : the gate can only open on every `interval`-th counted call
#   prob     : None, or a probability; when set, an otherwise-open gate is kept
#              only if random.random() <= prob
# Ids that are not listed here are always denied.
_gate_k2tu = { # id -> (max_run, interval, prob)
    'vis Dataset_vFrame perspectiveWarp' : ( 0, 1, None ),
    'vis LatentDiffusion.get_input' : ( 0, 5, None ),
    'vis LatentDiffusion.get_input-before_return True' : ( 0, 5, None ),
    'vis LatentDiffusion.get_input-before_return False' : ( 0, 1, None ),
    'vis LatentDiffusion.conditioning_with_feat' : ( 0, 2, None ),
    'vis LatentDiffusion.p_losses--after-apply_model' : ( 0, 2, None ),
    'statistics test_batch[0]' : ( 0, 2, None ),#-------------infer-----------
    "Project config:" : ( 0, 1, None ),# ------------for printC (arg[0] as id)-----------
    "Lightning config:" : ( 0, 1, None ),
    "logger_cfg=" : ( 0, 1, None ),
    "Merged modelckpt-cfg:" : ( 0, 1, None ),
    'bank get' : ( 0, 1, None ),
    'bank set' : ( 0, 1, None ),
    'clear' : ( 0, 1, None ),
    'mean ct:' : ( 0, 1, None ),
    "[__iter__]" : ( 1, 1, None ),
    "[_create_batches]" : ( 1, 1, None ),
    '[set_task_for_MoE]' : ( 1, 1, None ),
    'len_inter' : ( 3, 5, None ),
    'non_paired' : ( 3, 5, None ),
    'ddim rec bg' : ( 4, 5, None ),
    '[training step]' : ( 7, 1, None ),
    'LatentDiffusion.configure_optimizers params:' : ( 0, 1, None ),
    'c.shape' : ( 2, 6, None ),
    '[conditioning_with_feat return]': ( 0, 6, None ),
    'c for refNet' : ( 0, 6, None ),
    'hair _c.shape:' : ( 0, 1, None ),
    'head _c.shape:' : ( 0, 1, None ),
    'task' : ( 9, 1, None ),
    '_t_norm' : ( 9, 1, None ),
    'orig,ID clip,lpips rec lmk:' : (20, 2, None ),#-------------ddim_losses-----------
    'loss_lpips_1 at 0 0 :' : (10, 4, None ),
    'loss_lpips_1 at 0 1 :' : (10, 4, None ),
    'loss_lpips_1 at 0 2 :' : (10, 4, None ),
    'loss_lpips_1 at 1 0 :' : (10, 4, None ),
    'loss_lpips_1 at 1 1 :' : (10, 4, None ),
    'loss_lpips_1 at 1 2 :' : (10, 4, None ),
    'loss_rec_1 at 0 :' : (10, 4, None ),
    'loss_rec_1 at 1 :' : (10, 4, None ),
    'orig, ID clip, lpips rec lmk:' : (10, 4, None ),
    'c_ref True' : ( 3, 5, None ),
    'c_ref False' : ( 1, 1, None ),
    'ffn_gate_input' : ( 3, 3, None ),#-------------MoE-----------
    'vis-ffn_gate_input' : ( 3, 3, None ),
    '[warning]: no param to sync' : (10,1, None ),#-------------TP-----------
    '[TP] shared sync counts' : (10,1, None ),
    '[Conv2d param stats] count, name (sorted desc):': (0,1, None ),#-------------upcycle-----------
    'avg full_name=' : ( 0, 1, None ),
}
def gate_(id_, *args, **kw, ):  # gate for some vis or print behaviour, just for vis/debug
    """Decide whether the debug print / visualisation named ``id_`` may run now.

    The decision is driven by ``_gate_k2tu[id_] == (max_run, interval, prob)``
    and by the per-id counters ``_gate_total_runs`` / ``_gate_total_calls``.

    Args:
        id_: Key into ``_gate_k2tu``. Unknown ids are always denied.
        *args, **kw: Accepted and ignored.

    Returns:
        bool: ``True`` when the caller may proceed, ``False`` otherwise.
        (Previously the denied paths returned a mix of ``None``, ``0`` and
        ``False``; all of them are falsy, so truthiness is unchanged.)
    """
    # When `global_.TP_enable` is missing or falsy, only rank 0 of an
    # initialised torch.distributed group is ever allowed through.
    if not getattr(global_, 'TP_enable', False):
        import torch.distributed as dist  # lazy: only needed on this path
        if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
            return False

    rule = _gate_k2tu.get(id_)
    if rule is None:
        return False
    max_run, interval, prob = rule
    if max_run == 0:  # id explicitly switched off
        return False

    # The counter dicts are mutated in place, so no `global` statement is needed.
    runs = _gate_total_runs.setdefault(id_, 0)
    calls = _gate_total_calls.setdefault(id_, 0)
    if runs >= max_run:  # budget for this id is used up
        return False

    calls += 1
    _gate_total_calls[id_] = calls
    # Open only on every `interval`-th counted call. A falsy interval (0/None)
    # means "no throttling" instead of raising on the modulo.
    if interval and calls % interval != 0:
        return False

    if prob is not None:
        import random
        if random.random() > prob:
            return False

    _gate_total_runs[id_] = runs + 1
    return True
def str_t():  # eg. '0608-17.12.30'
    """Return the current local time formatted as ``MMDD-HH.MM.SS``."""
    from datetime import datetime

    # %m%d -> zero-padded month and day; %H.%M.%S -> zero-padded 24h time, dot-separated
    return datetime.now().strftime("%m%d-%H.%M.%S")
def str_t_pid():  # eg. '0608-17.12.30-180165'
    """Return ``str_t()`` followed by ``-`` and a per-process tag.

    The tag is ``global_.rank_`` when ``global_.TP_enable`` exists and is
    truthy; otherwise it is the OS process id.
    """
    tp_on = getattr(global_, 'TP_enable', False)
    tag = global_.rank_ if tp_on else os.getpid()
    return f"{str_t()}-{tag}"
def printC(*args, **kw):  # controlled print
    """Print ``*args`` only when ``gate_`` allows it.

    The first positional argument doubles as the gate id, so it must match a
    key of the gate table for anything to be printed.

    Args:
        *args: Values forwarded to ``print``; ``args[0]`` is the gate id.
        **kw: Keyword arguments forwarded to ``print``.

    Returns:
        None in every case.
    """
    # Without a positional argument there is no gate id to look up; treat the
    # call as a silent no-op instead of raising IndexError on args[0].
    if not args:
        return None
    if gate_(args[0]):
        return print(*args, **kw)
    return None
#--------------------
# Import-time side effect: importing this module sets this variable in the
# current process environment.
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0' # disable tf onednn-related warnings
# from skimage.io import imsave
def path_img_2_path_mask( path_img, check_mask_exists = 1 , reuse_if_exists = True, label_mode="RF12_"):
    """Map an image path to its semantic-mask path, generating the mask if needed.

    The mask lives next to the image as ``<image stem>-semantic_mask.png``.

    Args:
        path_img: Image path (str or Path); must not itself contain
            ``semantic_mask``.
        check_mask_exists: When truthy, a missing mask file is generated.
        reuse_if_exists: When falsy, the mask is always (re)generated.
        label_mode: Only ``"RF12_"`` is accepted.

    Returns:
        Path of the mask file.
    """
    # Only "RF12_" is supported, which also subsumes the looser membership
    # check that used to follow this assert.
    assert label_mode == "RF12_"
    assert 'semantic_mask' not in str(path_img), path_img

    image = Path(path_img)
    suffix_by_mode = {
        # "RF12" :'-semantic_mask',
        "RF12_": '-semantic_mask',
    }
    mask_name = f"{image.stem}{suffix_by_mode[label_mode]}.png"
    path_mask = image.parent / mask_name

    force_regen = not reuse_if_exists
    if not (check_mask_exists or force_regen):
        return path_mask

    mask_missing = not path_mask.exists()
    if mask_missing or force_regen:
        from gen_semantic_mask import gen_semantic_mask
        # Last argument is the optional visualisation path; none is written here.
        gen_semantic_mask(image, path_mask, label_mode, None)
    return path_mask
from my_py_lib.torchModuleName_util import *
# Disabled branch (`if 0`): would install IPython's ColorTB as sys.excepthook.
# NOTE(review): indentation was lost in this copy of the file; it is assumed that
# only the excepthook line belongs to this branch -- confirm against the original.
if 0:
    #-------------------- terminal color (only for exceptions/logging/warnings)
    import sys; from IPython.core.ultratb import ColorTB; sys.excepthook = ColorTB()
class _color:  # ANSI escape
    """ANSI escape sequences used to colour terminal output."""

    grey = "\x1b[90m"
    green = "\x1b[92m"
    yellow = "\x1b[93m"
    red = "\x1b[91m"
    orange = "\033[38;5;208m"
    orange_light = "\033[38;5;214m"


if 1:
    import logging

    class _CustomFormatter(logging.Formatter):
        """Formatter that wraps each record in a colour chosen by its level.

        ``FORMATS`` maps a logging level to a complete format string
        (colour + line layout + reset). Levels missing from ``FORMATS`` fall
        back to ``logging.Formatter``'s default layout.
        """

        # Line layout shared by every level. It has a private name so that it
        # is no longer shadowed by the ``format`` method below.
        # _LINE_FORMAT = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s [%(levelname)-8s] %(message)s"
        _LINE_FORMAT = "%(asctime)s | %(levelname)-5s | %(message)s"
        reset = "\x1b[0m"
        FORMATS = {
            logging.DEBUG: _color.grey + _LINE_FORMAT + reset,
            logging.INFO: _color.green + _LINE_FORMAT + reset,
            logging.WARNING: _color.yellow + _LINE_FORMAT + reset,
            logging.ERROR: _color.red + _LINE_FORMAT + reset,
            logging.CRITICAL: _color.red + _LINE_FORMAT + reset,
        }
        # format string (or None) -> ready-made Formatter, so a new Formatter
        # is not built for every log record. Keyed by the string rather than
        # the level so later edits to FORMATS are still honoured.
        _formatter_cache = {}

        def format(self, record):
            """Return ``record`` rendered with the colour for its level."""
            log_fmt = self.FORMATS.get(record.levelno)
            formatter = self._formatter_cache.get(log_fmt)
            if formatter is None:
                # datefmt prints only the time of day, not the date
                formatter = logging.Formatter(log_fmt, datefmt='%H:%M:%S')
                self._formatter_cache[log_fmt] = formatter
            return formatter.format(record)
def setup_colored_logging():
    """Attach a stream handler using ``_CustomFormatter`` to the root logger.

    Both the root logger and the new handler are set to DEBUG when
    ``LOG_debug_level`` is truthy, otherwise to INFO.
    """
    level = logging.DEBUG if LOG_debug_level else logging.INFO

    root = logging.getLogger()
    root.setLevel(level)

    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(_CustomFormatter())
    root.addHandler(handler)


# Runs once, as a side effect of importing this module.
setup_colored_logging()
if 1:
    import warnings

    def _custom_showwarning(msg, category, filename, lineno, file=None, line=None):
        """Replacement for ``warnings.showwarning`` that prints a coloured one-liner.

        Output is ``<Category>: <message>``, followed by ``<filename>:<lineno>``
        when ``LOG_debug_level`` is truthy. ``file`` and ``line`` are accepted
        for signature compatibility and ignored; the text is written with
        ``print``.
        """
        reset = "\x1b[0m"
        text = f"{_color.orange}{category.__name__}{reset}: {_color.yellow}{msg}{reset}"
        if LOG_debug_level:
            text += f" {_color.grey}{filename}:{lineno}{reset}"
        print(text)

    # Install the hook process-wide (import-time side effect).
    warnings.showwarning = _custom_showwarning
if __name__ == '__main__':
    # Manual smoke test for the coloured logging / warning output.
    logging.warning("This is a warning message in yellow")
    logging.error("This is an error message in red")
    warnings.warn("This is a colored warning message")