Instructions to use hansQAQ/icip_source_2 with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use hansQAQ/icip_source_2 with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("hansQAQ/icip_source_2", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import argparse | |
| import contextlib | |
| import logging | |
| import os | |
| import sys | |
| import multiprocessing as mp | |
class ColoredFilter(logging.Filter):
    """A logging filter that colorizes records of known levels in place.

    The level name is wrapped as ``<color>[LEVEL]`` and the ANSI reset code is
    appended to the end of the message, so the entire log line renders in the
    level's color on ANSI-capable terminals. Records are never filtered out.

    NOTE(review): the record is mutated, so if the same record passes through
    this filter more than once (e.g. attached to several handlers) the color
    codes are applied repeatedly — confirm single-handler usage.
    """

    # Fixed: RESET was previously defined twice ("\033[0m" and "\x1b[0m" are
    # the same byte sequence; the second definition silently shadowed the
    # first). A single definition is kept. The redundant __init__ that only
    # called super().__init__() has also been removed.
    RESET = "\033[0m"
    RED = "\033[31m"
    GREEN = "\033[32m"
    YELLOW = "\033[33m"
    BLUE = "\033[34m"
    MAGENTA = "\033[35m"
    CYAN = "\033[36m"

    # Level name -> ANSI color escape used for that level.
    COLORS = {
        "WARNING": YELLOW,
        "INFO": GREEN,
        "DEBUG": BLUE,
        "CRITICAL": MAGENTA,
        "ERROR": RED,
    }

    def filter(self, record: logging.LogRecord) -> bool:
        """Colorize ``record`` if its level is known; always return True."""
        if record.levelname in self.COLORS:
            color_start = self.COLORS[record.levelname]
            # Open the color before the bracketed level name...
            record.levelname = f"{color_start}[{record.levelname}]"
            # ...and close it after the message so the whole line is colored.
            record.msg = f"{record.msg}{self.RESET}"
        return True
def main(args: argparse.Namespace, extras: list) -> None:
    """Launch entry point: configure GPU visibility, build the Lightning
    Trainer, and run the requested stage (train / validate / test / export).

    Args:
        args: parsed known CLI arguments (see the argparse setup in __main__).
        extras: unparsed CLI leftovers, forwarded to the YAML config loader
            as command-line overrides.
    """
    # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning
    # NOTE: GPU selection must happen BEFORE torch / pytorch_lightning are
    # imported, since CUDA device visibility is fixed at CUDA init time.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else []
    selected_gpus = [0]
    # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified.
    # As far as Pytorch Lightning is concerned, we always use all available GPUs
    # (possibly filtered by CUDA_VISIBLE_DEVICES).
    devices = -1
    if len(env_gpus) > 0:
        # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script.
        n_gpus = len(env_gpus)
    else:
        # Honor the --gpu argument by exporting it ourselves.
        selected_gpus = list(args.gpu.split(","))
        n_gpus = len(selected_gpus)
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    import pytorch_lightning as pl
    import torch
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, DeviceStatsMonitor
    from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger
    from pytorch_lightning.utilities.rank_zero import rank_zero_only
    if args.typecheck:
        # Optional runtime shape/type checking via jaxtyping + typeguard.
        from jaxtyping import install_import_hook
        install_import_hook("mvdiff", "typeguard.typechecked")
    from midi.systems.base import BaseSystem
    from midi.utils.callbacks import (
        CodeSnapshotCallback,
        ConfigSnapshotCallback,
        CustomProgressBar,
        ProgressCallback,
    )
    from midi.utils.config import ExperimentConfig, load_config
    from midi.utils.core import find
    from midi.utils.misc import get_rank, time_recorder
    from midi.utils.typing import Optional
    logger = logging.getLogger("pytorch_lightning")
    if args.verbose:
        logger.setLevel(logging.DEBUG)
    if args.benchmark:
        # Enable recording of running times for instrumented sections.
        time_recorder.enable(True)
    for handler in logger.handlers:
        if handler.stream == sys.stderr:  # type: ignore
            if not args.gradio:
                handler.setFormatter(logging.Formatter("%(levelname)s %(message)s"))
                handler.addFilter(ColoredFilter())
            else:
                # Plain formatting in gradio mode: ANSI colors would garble the UI log.
                handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
    # parse YAML config to OmegaConf
    cfg: ExperimentConfig
    cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus)
    # debug
    if args.lr:
        # Debug override: replace the configured learning rate and tag the run name.
        print(cfg.system)
        cfg.system['optimizer']['args']['lr'] = args.lr
        cfg.name = cfg.tag + f"_lr-{args.lr}"
    # Instantiate datamodule and system from dotted class paths in the config.
    dm = find(cfg.data_cls)(cfg.data)
    system: BaseSystem = find(cfg.system_cls)(
        cfg.system, resumed=cfg.resume is not None
    )
    system.set_save_dir(os.path.join(cfg.trial_dir, "save"))
    if args.gradio:
        # Mirror logs to a file so the gradio frontend can display them.
        fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs"))
        fh.setLevel(logging.INFO)
        if args.verbose:
            fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s"))
        logger.addHandler(fh)
    callbacks = []
    if args.train:
        callbacks += [
            ModelCheckpoint(
                dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint
            ),
            LearningRateMonitor(logging_interval="step"),
            # CodeSnapshotCallback(
            #     os.path.join(cfg.trial_dir, "code"), use_version=False
            # ),
            ConfigSnapshotCallback(
                args.config,
                cfg,
                os.path.join(cfg.trial_dir, "configs"),
                use_version=False,
            ),
            DeviceStatsMonitor()
        ]
    if args.gradio:
        callbacks += [
            ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress"))
        ]
    else:
        callbacks += [CustomProgressBar(refresh_rate=1)]
    def write_to_text(file, lines):
        # Write each entry of `lines` to `file`, one per row.
        with open(file, "w") as f:
            for line in lines:
                f.write(line + "\n")
    loggers = []
    loggers += [
        TensorBoardLogger(cfg.trial_dir, name="tb_logs"),
    ]
    if args.wandb:
        print("通天")
        def get_wandb_safe_config(cfg):
            """Extract WandB-safe configuration entries."""
            safe_config = {}
            # Only keep basic data types.
            basic_types = (str, int, float, bool, type(None))
            for key, value in cfg.__dict__.items():
                # Skip private attributes.
                if key.startswith('_'):
                    continue
                # Keep values that are plain basic types as-is.
                if isinstance(value, basic_types):
                    safe_config[key] = value
                elif isinstance(value, (list, tuple)):
                    # Keep lists/tuples only if every element is a basic type.
                    if all(isinstance(item, basic_types) for item in value):
                        safe_config[key] = value
                elif isinstance(value, dict):
                    # Keep dicts only if all keys and values are basic types.
                    if all(isinstance(k, basic_types) and isinstance(v, basic_types)
                           for k, v in value.items()):
                        safe_config[key] = value
                else:
                    # Any other type is converted to its string representation.
                    safe_config[key] = str(value)
            return safe_config
        wandb_logger = WandbLogger(
            project="MIDI-sketch",
            name=f"{cfg.name}-{cfg.tag}",
            save_code=True,
            config=get_wandb_safe_config(cfg)
        )
        wandb_logger.experiment.save(args.config)
        system._wandb_logger = wandb_logger
        loggers += [wandb_logger]
    if args.train:
        # make tensorboard logging dir to suppress warning
        rank_zero_only(
            lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True)
        )()
        # Record the exact launch command for reproducibility.
        rank_zero_only(
            lambda: write_to_text(
                os.path.join(cfg.trial_dir, "cmd.txt"),
                ["python " + " ".join(sys.argv), str(args)],
            )
        )()
    from pytorch_lightning.profilers import AdvancedProfiler
    profiler = AdvancedProfiler(dirpath=".", filename="perf_logs")
    trainer = Trainer(
        # @TODO: Check how to parallel model to accelerate training process.
        # overfit_batches=0.05,
        # limit_val_batches=0.2,d
        callbacks=callbacks,
        logger=loggers,
        inference_mode=False,
        accelerator="gpu",
        devices=devices,
        profiler=profiler,
        **cfg.trainer,
    )
    # set a different seed for each device
    # NOTE: use trainer.global_rank instead of get_rank() to avoid getting the local rank
    pl.seed_everything(cfg.seed + trainer.global_rank, workers=True)
    def set_system_status(system: BaseSystem, ckpt_path: Optional[str]):
        # Manually restore epoch/global_step counters from a checkpoint file.
        if ckpt_path is None:
            return
        ckpt = torch.load(ckpt_path, map_location="cpu")
        system.set_resume_status(ckpt["epoch"], ckpt["global_step"])
    if args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume)
        trainer.test(system, datamodule=dm)
        if args.gradio:
            # also export assets if in gradio mode
            trainer.predict(system, datamodule=dm)
    elif args.validate:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.test:
        # manually set epoch and global_step as they cannot be automatically resumed
        set_system_status(system, cfg.resume)
        trainer.test(system, datamodule=dm, ckpt_path=cfg.resume)
    elif args.export:
        set_system_status(system, cfg.resume)
        trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume)
| if __name__ == "__main__": | |
| mp.set_start_method('spawn', force=True) | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--config", required=True, help="path to config file") | |
| parser.add_argument( | |
| "--gpu", | |
| default="0", | |
| help="GPU(s) to be used. 0 means use the 1st available GPU. " | |
| "1,2 means use the 2nd and 3rd available GPU. " | |
| "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " | |
| "this argument is ignored and all available GPUs are always used.", | |
| ) | |
| group = parser.add_mutually_exclusive_group(required=True) | |
| group.add_argument("--train", action="store_true") | |
| group.add_argument("--validate", action="store_true") | |
| group.add_argument("--test", action="store_true") | |
| group.add_argument("--export", action="store_true") | |
| parser.add_argument("--wandb", action="store_true", help="if true, log to wandb") | |
| parser.add_argument( | |
| "--gradio", action="store_true", help="if true, run in gradio mode" | |
| ) | |
| parser.add_argument( | |
| "--verbose", action="store_true", help="if true, set logging level to DEBUG" | |
| ) | |
| parser.add_argument( | |
| "--benchmark", | |
| action="store_true", | |
| help="if true, set to benchmark mode to record running times", | |
| ) | |
| parser.add_argument( | |
| "--typecheck", | |
| action="store_true", | |
| help="whether to enable dynamic type checking", | |
| ) | |
| # debug use | |
| parser.add_argument( | |
| "--lr", | |
| type=float, | |
| ) | |
| args, extras = parser.parse_known_args() | |
| if args.gradio: | |
| # FIXME: no effect, stdout is not captured | |
| with contextlib.redirect_stdout(sys.stderr): | |
| main(args, extras) | |
| else: | |
| import torch | |
| torch.backends.cuda.sdp_kernel(enable_flash=True, enable_mem_efficient=False, enable_math=False) | |
| main(args, extras) | |