"""Training/evaluation launcher: parses CLI args, configures GPUs and logging,
builds the datamodule and system from a YAML config, and runs the requested
pytorch-lightning Trainer stage (fit / validate / test / predict)."""

import argparse
import contextlib
import logging
import multiprocessing as mp
import os
import sys


class ColoredFilter(logging.Filter):
    """A logging filter that wraps the level name of known levels in an ANSI
    color code and appends a reset code to the message."""

    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    # NOTE(review): the original defined RESET twice ("\033[0m" then
    # "\x1b[0m"); both spell the same escape sequence, so keep a single one.
    RESET = "\x1b[0m"

    def filter(self, record):
        # Mutates the record in place: colorize the level name and append a
        # reset code so the color does not bleed into subsequent output.
        if record.levelname in self.COLORS:
            color_start = self.COLORS[record.levelname]
            record.levelname = f"{color_start}[{record.levelname}]"
            record.msg = f"{record.msg}{self.RESET}"
        return True


def main(args, extras) -> None:
    """Configure GPU visibility and logging, build the datamodule/system from
    the YAML config, and dispatch to the Trainer stage selected by *args*.

    Args:
        args: namespace from ``parser.parse_known_args()`` (see ``__main__``).
        extras: unparsed CLI tokens, forwarded to ``load_config`` as overrides.
    """
    # Set CUDA_VISIBLE_DEVICES if needed BEFORE importing torch /
    # pytorch-lightning, otherwise the device selection has no effect.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = env_gpus_str.split(",") if env_gpus_str else []
    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available
    # GPUs (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if env_gpus:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or a
        # higher-level script.
        n_gpus = len(env_gpus)
    else:
        n_gpus = len(args.gpu.split(","))
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import (
        DeviceStatsMonitor,
        LearningRateMonitor,
        ModelCheckpoint,
    )
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only

    if args.typecheck:
        from jaxtyping import install_import_hook

        install_import_hook("mvdiff", "typeguard.typechecked")

    from midi.systems.base import BaseSystem
    from midi.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from midi.utils.config import ExperimentConfig, load_config
    from midi.utils.core import find
    from midi.utils.misc import get_rank, time_recorder
    from midi.utils.typing import Optional

    logger = logging.getLogger("pytorch_lightning")

    if args.verbose:
        logger.setLevel(logging.DEBUG)

    if args.benchmark:
        time_recorder.enable(True)

    for handler in logger.handlers:
        # Only reformat stream handlers attached to stderr; getattr() skips
        # handlers without a `stream` attribute instead of raising.
        if getattr(handler, "stream", None) is sys.stderr:
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))

    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)

    # debug: override the learning rate from the command line
    if args.lr:
        print(cfg.system)
        cfg.system['optimizer']['args']['lr'] = args.lr
        cfg.name = cfg.tag + f"_lr-{args.lr}"

    dm = find(cfg.data_cls)(cfg.data)
    system: BaseSystem = find(cfg.system_cls)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))

    if args.gradio:
        # In gradio mode, also mirror the logs to a file inside the trial dir.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)

    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     os.path.join(cfg.trial_dir, "code"), use_version=False
            # ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
            DeviceStatsMonitor(),
        ]
    if args.gradio:
        callbacks += [
            ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
        ]
    else:
        callbacks += [CustomProgressBar(refresh_rate=1)]

    def write_to_text(file, lines):
        # Helper: dump a list of strings as lines of a text file.
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")

    loggers = []
    loggers += [
        TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
    ]
    if args.wandb:

        def get_wandb_safe_config(cfg):
            """Extract a WandB-safe view of the config: keep entries whose
            values are basic types (or flat containers of basic types) and
            stringify everything else."""
            safe_config = {}
            basic_types = (str, int, float, bool, type(None))
            for key, value in cfg.__dict__.items():
                # Skip private attributes.
                if key.startswith('_'):
                    continue
                if isinstance(value, basic_types):
                    safe_config[key] = value
                elif isinstance(value, (list, tuple)):
                    # Keep sequences only if every element is a basic type.
                    if all(isinstance(item, basic_types) for item in value):
                        safe_config[key] = value
                elif isinstance(value, dict):
                    # Keep dicts only if all keys and values are basic types.
                    if all(
                        isinstance(k, basic_types) and isinstance(v, basic_types)
                        for k, v in value.items()
                    ):
                        safe_config[key] = value
                else:
                    # Fall back to the string representation.
                    safe_config[key] = str(value)
            return safe_config

        wandb_logger = WandbLogger(
            project="MIDI-sketch",
            name=f"{cfg.name}-{cfg.tag}",
            save_code=True,
            config=get_wandb_safe_config(cfg),
        )
        wandb_logger.experiment.save(args.config)
        system._wandb_logger = wandb_logger
        loggers += [wandb_logger]

    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        # Record the exact launch command for reproducibility.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()

    from pytorch_lightning.profilers import AdvancedProfiler

    profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")

    trainer = Trainer(
        # @TODO: Check how to parallel model to accelerate training process.
        # overfit_batches=0.05,
        # limit_val_batches=0.2,
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        profiler=profiler,
        **cfg.trainer,
    )

    # set a different seed for each device
    # NOTE: use trainer.global_rank instead of get_rank() to avoid getting
    # the local rank
    pl.seed_everything(cfg.seed + trainer.global_rank, workers=True)

    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        # Manually restore epoch / global_step from the checkpoint; they are
        # not automatically resumed outside of Trainer.fit.
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])

    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)


if __name__ == "__main__":
    # Spawn (not fork) so CUDA contexts are not inherited by workers.
    mp.set_start_method('spawn', force=True)

    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="path to config file")
    parser.add_argument(
        "--gpu",
        default="0",
        help="GPU(s) to be used. 0 means use the 1st available GPU. "
        "1,2 means use the 2nd and 3rd available GPU. "
        "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, "
        "this argument is ignored and all available GPUs are always used.",
    )
    group = parser.add_mutually_exclusive_group(required=True)
    group.add_argument("--train", action="store_true")
    group.add_argument("--validate", action="store_true")
    group.add_argument("--test", action="store_true")
    group.add_argument("--export", action="store_true")
    parser.add_argument("--wandb", action="store_true", help="if true, log to wandb")
    parser.add_argument(
        "--gradio", action="store_true", help="if true, run in gradio mode"
    )
    parser.add_argument(
        "--verbose", action="store_true", help="if true, set logging level to DEBUG"
    )
    parser.add_argument(
        "--benchmark",
        action="store_true",
        help="if true, set to benchmark mode to record running times",
    )
    parser.add_argument(
        "--typecheck",
        action="store_true",
        help="whether to enable dynamic type checking",
    )
    # debug use
    parser.add_argument(
        "--lr",
        type=float,
    )
    args, extras = parser.parse_known_args()

    if args.gradio:
        # FIXME: no effect, stdout is not captured
        with contextlib.redirect_stdout(sys.stderr):
            main(args, extras)
    else:
        import torch

        # Restrict scaled-dot-product attention to the flash kernel.
        torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_mem_efficient=False, enable_math=False
        )
        main(args, extras)