# Source: icip_source_2/launch.py — uploaded to the Hugging Face Hub by hansQAQ
# ("Upload folder using huggingface_hub", commit 278bf35, verified).
# Repository tags: Diffusers, Safetensors.
import os
import argparse
import contextlib
import logging
import os
import sys
import multiprocessing as mp
class ColoredFilter(logging.Filter):
    """Logging filter that paints records of known levels with ANSI colors.

    For a record whose level name is a key of ``COLORS``, the level name is
    rewritten to ``<color>[LEVELNAME]`` and ``RESET`` is appended to the raw
    message, so the level tag and the message render in that color. Records of
    other levels are left untouched. Every record is let through (the filter
    never drops anything).

    The record is modified in place, so any other handler that formats the
    same record afterwards also sees the colored fields.
    """

    # ANSI SGR escape sequences.
    RESET = "\x1b[0m"
    RED = "\x1b[31m"
    GREEN = "\x1b[32m"
    YELLOW = "\x1b[33m"
    BLUE = "\x1b[34m"
    MAGENTA = "\x1b[35m"
    CYAN = "\x1b[36m"

    # Level name -> escape sequence; levels missing here stay uncolored.
    COLORS = {
        "DEBUG": BLUE,
        "INFO": GREEN,
        "WARNING": YELLOW,
        "ERROR": RED,
        "CRITICAL": MAGENTA,
    }

    def __init__(self):
        super().__init__()

    def filter(self, record):
        """Colorize ``record`` in place when its level is known; always keep it."""
        color = self.COLORS.get(record.levelname)
        if color is not None:
            record.levelname = f"{color}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True
def main(args, extras) -> None:
    """Configure GPUs/logging, build data + system from a config, and run Lightning.

    Flow: select GPUs via ``CUDA_VISIBLE_DEVICES``, import PyTorch Lightning and
    project modules, configure the ``pytorch_lightning`` logger, load the
    experiment config, instantiate the datamodule and system with ``find``,
    assemble callbacks/loggers/profiler, build a ``Trainer`` and finally run
    fit+test (``--train``), validate (``--validate``), test (``--test``) or
    predict (``--export``).

    Args:
        args: Parsed ``argparse.Namespace`` from the ``__main__`` block. Reads
            ``gpu``, ``typecheck``, ``verbose``, ``benchmark``, ``gradio``,
            ``lr``, ``wandb``, ``config`` and the mode flags ``train`` /
            ``validate`` / ``test`` / ``export``.
        extras: Command-line tokens argparse did not recognize; passed to
            ``load_config`` as ``cli_args``.
    """
    # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning.
    # Keep the heavy imports below AFTER this block so the env vars are in
    # place before any CUDA work happens.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []

    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if len(env_gpus) > 0:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
        n_gpus = len(env_gpus)
    else:
        selected_gpus = list(args.gpu.split(","))
        n_gpus = len(selected_gpus)
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        # The hook only affects modules imported after it is installed, so it
        # must name the package imported right below ("midi"); "mvdiff" is
        # kept from the previous version for backward compatibility.
        install_import_hook(["midi", "mvdiff"], "typeguard.typechecked")

    from midi.systems.base import BaseSystem
    from midi.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from midi.utils.config import ExperimentConfig, load_config
    from midi.utils.core import find
    from midi.utils.misc import get_rank, time_recorder
    from midi.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.benchmark:
        time_recorder.enable(True)

    for handler in logger.handlers:
        # Not every handler type has a ``stream`` attribute (e.g. NullHandler);
        # only restyle the ones that write to stderr.
        if getattr(handler, "stream", None) == sys.stderr:
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # debug: override the optimizer learning rate from the command line.
    # ``is not None`` so that an explicit ``--lr 0`` is honoured too.
    if args.lr is not None:
        print(cfg.system)
        cfg.system['optimizer']['args']['lr'] = args.lr
        cfg.name = cfg.tag + f"_lr-{args.lr}"

    dm = find(cfg.data_cls)(cfg.data)
    system: BaseSystem = find(cfg.system_cls)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.DEBUG if args.verbose else logging.INFO)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        # CodeSnapshotCallback (snapshots into trial_dir/code) is currently disabled.
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
            DeviceStatsMonitor(),
        ]
        if args.gradio:
            callbacks += [
                ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
            ]
        else:
            callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        """Write ``lines`` to ``file``, one per line (overwrites the file)."""
        # Explicit UTF-8: sys.argv may contain non-ASCII paths/arguments.
        with open(file, "w", encoding="utf-8") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = [TensorBoardLogger(cfg.trial_dir, name="tb_logs")]
    if args.wandb:

        def get_wandb_safe_config(cfg):
            """Return the top-level config fields in a form WandB can serialize.

            Private (``_``-prefixed) keys are skipped. Scalars, and lists /
            tuples / dicts made only of scalars, are kept as-is; every other
            value (including containers with nested non-scalar items) is
            stored as ``str(value)`` instead of being silently dropped.
            """
            from collections.abc import Mapping

            basic_types = (str, int, float, bool, type(None))
            # NOTE(review): if load_config returns an OmegaConf container (a
            # Mapping), its ``__dict__`` holds only private bookkeeping fields,
            # which would yield an empty config; read its items instead.
            # Plain objects (e.g. dataclass instances) still go through vars().
            items = cfg.items() if isinstance(cfg, Mapping) else vars(cfg).items()
            safe_config = {}
            for key, value in items:
                # Skip private attributes.
                if isinstance(key, str) and key.startswith('_'):
                    continue
                if isinstance(value, basic_types):
                    safe_config[key] = value
                elif isinstance(value, (list, tuple)) and all(
                    isinstance(item, basic_types) for item in value
                ):
                    safe_config[key] = value
                elif isinstance(value, dict) and all(
                    isinstance(k, basic_types) and isinstance(v, basic_types)
                    for k, v in value.items()
                ):
                    safe_config[key] = value
                else:
                    # Anything else (nested containers, custom objects) -> string.
                    safe_config[key] = str(value)
            return safe_config

        wandb_logger = WandbLogger(
            project="MIDI-sketch",
            name=f"{cfg.name}-{cfg.tag}",
            save_code=True,
            config=get_wandb_safe_config(cfg),
        )
        wandb_logger.experiment.save(args.config)
        system._wandb_logger = wandb_logger
        loggers.append(wandb_logger)

    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    from pytorch_lightning.profilers import AdvancedProfiler

    # NOTE(review): the profiler is enabled for every run mode and writes its
    # report ("perf_logs") into the current working directory, not trial_dir.
    # AdvancedProfiler is cProfile-based and can add overhead; consider
    # gating it behind --benchmark.
    profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
    trainer = Trainer(
        # @TODO: Check how to parallel model to accelerate training process.
        # Debug knobs such as overfit_batches / limit_val_batches can be passed
        # through cfg.trainer.
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        profiler=profiler,
        **cfg.trainer,
    )

    # set a different seed for each device
    # NOTE: use trainer.global_rank instead of get_rank() to avoid getting the local rank
    pl.seed_everything(cfg.seed + trainer.global_rank, workers=True)

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        """Copy ``epoch``/``global_step`` from ``ckpt_path`` into ``system``; no-op if None."""
        if ckpt_path is None:
            return
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
if __name__ == "__main__":
    # Script entry point: parse the CLI and hand off to ``main``.
    mp.set_start_method('spawn', force=True)
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )

    # Exactly one run mode must be chosen.
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")

    parser.add_argument("--wandb", action="store_true", help="if true, log to wandb")
    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="if true, set to benchmark mode to record running times",
    )
    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )
    # debug use
    parser.add_argument(
        "--lr",
        type=float,
        default=None,
        help="(debug) override the optimizer learning rate from the config",
    )

    # Unknown tokens are forwarded to ``main`` as config overrides.
    args, extras = parser.parse_known_args()

    if args.gradio:
        # FIXME: no effect, stdout is not captured
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        import torch

        # ``sdp_kernel`` is a context manager; the previous bare call only
        # built the manager and never applied these settings. Enter it around
        # ``main`` so scaled_dot_product_attention is actually restricted to
        # the flash kernel.
        # NOTE(review): with the math and mem-efficient fallbacks disabled,
        # SDPA raises if the flash kernel cannot handle the inputs
        # (dtype / GPU / mask). Re-enable a fallback if that happens.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_mem_efficient=False, enable_math=False
        ):
            main(args, extras)