VideoVAEPlus-tactile / utils /train_utils.py
WitneyWW's picture
Add source, configs, inference scripts
e7c18b3 verified
Raw
History Blame Contribute Delete
9.36 kB
import os
from omegaconf import OmegaConf
from collections import OrderedDict
import logging
mainlogger = logging.getLogger("mainlogger")
import torch
from collections import OrderedDict
def init_workspace(name, logdir, model_config, lightning_config, rank=0):
workdir = os.path.join(logdir, name)
ckptdir = os.path.join(workdir, "checkpoints")
cfgdir = os.path.join(workdir, "configs")
loginfo = os.path.join(workdir, "loginfo")
# Create logdirs and save configs (all ranks will do to avoid missing directory error if rank:0 is slower)
os.makedirs(workdir, exist_ok=True)
os.makedirs(ckptdir, exist_ok=True)
os.makedirs(cfgdir, exist_ok=True)
os.makedirs(loginfo, exist_ok=True)
if rank == 0:
if (
"callbacks" in lightning_config
and "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks
):
os.makedirs(os.path.join(ckptdir, "trainstep_checkpoints"), exist_ok=True)
OmegaConf.save(model_config, os.path.join(cfgdir, "model.yaml"))
OmegaConf.save(
OmegaConf.create({"lightning": lightning_config}),
os.path.join(cfgdir, "lightning.yaml"),
)
return workdir, ckptdir, cfgdir, loginfo
def check_config_attribute(config, name):
if name in config:
value = getattr(config, name)
return value
else:
return None
def get_trainer_callbacks(lightning_config, config, logdir, ckptdir, logger):
default_callbacks_cfg = {
"model_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch}",
"verbose": True,
"save_last": True,
},
},
"batch_logger": {
"target": "utils.callbacks.ImageLogger",
"params": {
"save_dir": logdir,
"batch_frequency": 1000,
"max_images": 4,
"clamp": True,
},
},
"learning_rate_logger": {
"target": "pytorch_lightning.callbacks.LearningRateMonitor",
"params": {"logging_interval": "step", "log_momentum": False},
},
"cuda_callback": {"target": "utils.callbacks.CUDACallback"},
}
## optional setting for saving checkpoints
monitor_metric = check_config_attribute(config.model.params, "monitor")
if monitor_metric is not None:
mainlogger.info(f"Monitoring {monitor_metric} as checkpoint metric.")
default_callbacks_cfg["model_checkpoint"]["params"]["monitor"] = monitor_metric
default_callbacks_cfg["model_checkpoint"]["params"]["save_top_k"] = 3
default_callbacks_cfg["model_checkpoint"]["params"]["mode"] = "min"
if "metrics_over_trainsteps_checkpoint" in lightning_config.callbacks:
mainlogger.info(
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
)
default_metrics_over_trainsteps_ckpt_dict = {
"metrics_over_trainsteps_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
"filename": "{epoch}-{step}",
"verbose": True,
"save_top_k": -1,
"every_n_train_steps": 10000,
"save_weights_only": True,
},
}
}
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
return callbacks_cfg
def get_trainer_logger(lightning_config, logdir, on_debug):
default_logger_cfgs = {
"tensorboard": {
"target": "pytorch_lightning.loggers.TensorBoardLogger",
"params": {
"save_dir": logdir,
"name": "tensorboard",
},
},
"testtube": {
"target": "pytorch_lightning.loggers.CSVLogger",
"params": {
"name": "testtube",
"save_dir": logdir,
},
},
}
os.makedirs(os.path.join(logdir, "tensorboard"), exist_ok=True)
default_logger_cfg = default_logger_cfgs["tensorboard"]
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
return logger_cfg
def get_trainer_strategy(lightning_config):
default_strategy_dict = {
"target": "pytorch_lightning.strategies.DDPShardedStrategy"
}
if "strategy" in lightning_config:
strategy_cfg = lightning_config.strategy
return strategy_cfg
else:
strategy_cfg = OmegaConf.create()
strategy_cfg = OmegaConf.merge(default_strategy_dict, strategy_cfg)
return strategy_cfg
def load_checkpoints(model, model_cfg):
## special load setting for adapter training
if check_config_attribute(model_cfg, "adapter_only"):
pretrained_ckpt = model_cfg.pretrained_checkpoint
assert os.path.exists(pretrained_ckpt), (
"Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
)
mainlogger.info(
">>> Load weights from pretrained checkpoint (training adapter only)"
)
print(f"Loading model from {pretrained_ckpt}")
## only load weight for the backbone model (e.g. latent diffusion model)
state_dict = torch.load(pretrained_ckpt, map_location=f"cpu")
if "state_dict" in list(state_dict.keys()):
state_dict = state_dict["state_dict"]
else:
# deepspeed
dp_state_dict = OrderedDict()
for key in state_dict["module"].keys():
dp_state_dict[key[16:]] = state_dict["module"][key]
state_dict = dp_state_dict
model.load_state_dict(state_dict, strict=False)
model.empty_paras = None
return model
empty_paras = None
if check_config_attribute(model_cfg, "pretrained_checkpoint"):
pretrained_ckpt = model_cfg.pretrained_checkpoint
assert os.path.exists(pretrained_ckpt), (
"Error: Pre-trained checkpoint NOT found at:%s" % pretrained_ckpt
)
mainlogger.info(">>> Load weights from pretrained checkpoint")
# mainlogger.info(pretrained_ckpt)
print("Loading model from {pretrained_ckpt}")
pl_sd = torch.load(pretrained_ckpt, map_location="cpu")
try:
if "state_dict" in pl_sd.keys():
model.load_state_dict(pl_sd["state_dict"])
else:
# deepspeed
new_pl_sd = OrderedDict()
for key in pl_sd["module"].keys():
new_pl_sd[key[16:]] = pl_sd["module"][key]
model.load_state_dict(new_pl_sd)
except:
model.load_state_dict(pl_sd)
else:
empty_paras = None
## record empty params
model.empty_paras = empty_paras
return model
def get_autoresume_path(logdir):
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
if os.path.exists(ckpt):
try:
tmp = torch.load(ckpt, map_location="cpu")
e = tmp["epoch"]
gs = tmp["global_step"]
mainlogger.info(f"[INFO] Resume from epoch {e}, global step {gs}!")
del tmp
except:
try:
mainlogger.info("Load last.ckpt failed!")
ckpts = sorted(
[
f
for f in os.listdir(os.path.join(logdir, "checkpoints"))
if not os.path.isdir(f)
]
)
mainlogger.info(f"all avaible checkpoints: {ckpts}")
ckpts.remove("last.ckpt")
if "trainstep_checkpoints" in ckpts:
ckpts.remove("trainstep_checkpoints")
ckpt_path = ckpts[-1]
ckpt = os.path.join(logdir, "checkpoints", ckpt_path)
mainlogger.info(f"Select resuming ckpt: {ckpt}")
except ValueError:
mainlogger.info("Load last.ckpt failed! and there is no other ckpts")
resume_checkpt_path = ckpt
mainlogger.info(f"[INFO] resume from: {ckpt}")
else:
resume_checkpt_path = None
mainlogger.info(
f"[INFO] no checkpoint found in current workspace: {os.path.join(logdir, 'checkpoints')}"
)
return resume_checkpt_path
def set_logger(logfile, name="mainlogger"):
logger = logging.getLogger(name)
logger.setLevel(logging.INFO)
fh = logging.FileHandler(logfile, mode="w")
fh.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
fh.setFormatter(logging.Formatter("%(asctime)s-%(levelname)s: %(message)s"))
ch.setFormatter(logging.Formatter("%(message)s"))
logger.addHandler(fh)
logger.addHandler(ch)
return logger