| """ |
| This repo is forked from [Boyuan Chen](https://boyuan.space/)'s research |
| template [repo](https://github.com/buoyancy99/research-template). |
| By its MIT license, you must keep the above sentence in `README.md` |
| and the `LICENSE` file to credit the author. |
| """ |
|
|
| from abc import ABC, abstractmethod |
| from typing import Optional, Union, Literal, List, Dict |
| import pathlib |
| import os |
|
|
| import hydra |
| import torch |
| from lightning.pytorch.strategies.ddp import DDPStrategy |
|
|
| import lightning.pytorch as pl |
| from lightning.pytorch.loggers.wandb import WandbLogger |
| from lightning.pytorch.utilities.types import TRAIN_DATALOADERS |
| from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint |
| from pytorch_lightning.utilities import rank_zero_info |
|
|
| from omegaconf import DictConfig |
|
|
| from utils.print_utils import cyan |
| from utils.distributed_utils import is_rank_zero |
| from safetensors.torch import load_file, load_model |
| from pathlib import Path |
| from huggingface_hub import hf_hub_download |
| from huggingface_hub import model_info |
|
|
# Process-wide PyTorch setting applied as a side effect of importing this module:
# selects the "high" float32 matmul precision mode.
# NOTE(review): presumably chosen to trade a little float32 matmul precision for
# speed on GPUs that support it -- confirm against the PyTorch documentation.
torch.set_float32_matmul_precision("high")
|
|
def is_huggingface_model(path: str) -> bool:
    """
    Check whether ``path`` starts with a repo id known to the Hugging Face Hub.

    The first two ``/``-separated components of ``path`` are joined into a
    candidate repo id, which is then looked up with ``model_info``.

    Args:
        path: checkpoint location such as ``"org/repo/file.ckpt"``; it is passed
            through ``str`` so path-like objects are accepted as well

    Returns:
        True if ``model_info`` succeeds for the candidate repo id, False if the
        lookup raises any ordinary exception.
    """
    repo_id = "/".join(str(path).split("/")[:2])
    try:
        model_info(repo_id)
    except Exception:
        # Best-effort probe: any lookup failure means "not a Hub model".
        # ``Exception`` rather than a bare ``except`` so that KeyboardInterrupt
        # and SystemExit still propagate to the caller.
        return False
    return True
| |
| def _extract_state_dict(checkpoint): |
| if isinstance(checkpoint, dict) and isinstance(checkpoint.get("state_dict"), dict): |
| return checkpoint["state_dict"] |
| return checkpoint |
|
|
|
|
| def _compatible_state_dict(algo, state_dict): |
| model_state = algo.state_dict() |
| prefixes = ( |
| "", |
| "model.", |
| "module.", |
| "algo.", |
| "diffusion_model.", |
| "diffusion_model.model.", |
| "vae.", |
| ) |
| best = ( |
| {}, |
| { |
| "matched": 0, |
| "missing": 0, |
| "unexpected": 0, |
| "skipped_prefix": 0, |
| "shape_mismatch": 0, |
| "missing_model": len(model_state), |
| "total": len(state_dict), |
| "prefix": "", |
| "unexpected_checkpoint_keys": [], |
| "shape_mismatch_keys": [], |
| "missing_model_keys": list(model_state.keys()), |
| }, |
| ) |
| for prefix in prefixes: |
| compatible = {} |
| unexpected_checkpoint_keys = [] |
| shape_mismatch_keys = [] |
| skipped_prefix = 0 |
| for key, value in state_dict.items(): |
| if key in ["data_mean", "data_std"]: |
| continue |
| if prefix and not key.startswith(prefix): |
| skipped_prefix += 1 |
| continue |
| stripped = key.removeprefix(prefix) |
| if stripped not in model_state: |
| unexpected_checkpoint_keys.append(key) |
| continue |
| if hasattr(value, "shape") and value.shape != model_state[stripped].shape: |
| shape_mismatch_keys.append( |
| ( |
| key, |
| stripped, |
| tuple(value.shape), |
| tuple(model_state[stripped].shape), |
| ) |
| ) |
| continue |
| compatible[stripped] = value |
| missing_model_keys = [key for key in model_state.keys() if key not in compatible] |
| if len(compatible) > best[1]["matched"]: |
| best = ( |
| compatible, |
| { |
| "matched": len(compatible), |
| "missing": skipped_prefix + len(unexpected_checkpoint_keys), |
| "unexpected": len(unexpected_checkpoint_keys), |
| "skipped_prefix": skipped_prefix, |
| "shape_mismatch": len(shape_mismatch_keys), |
| "missing_model": len(missing_model_keys), |
| "total": len(state_dict), |
| "prefix": prefix, |
| "unexpected_checkpoint_keys": unexpected_checkpoint_keys, |
| "shape_mismatch_keys": shape_mismatch_keys, |
| "missing_model_keys": missing_model_keys, |
| }, |
| ) |
| return best |
|
|
|
|
| def _key_matches_any(key: str, markers: tuple[str, ...]) -> bool: |
| return any(marker in key for marker in markers) |
|
|
|
|
| def _format_key_samples(keys, limit: int = 25, indent: str = " ") -> str: |
| sample = list(keys[:limit]) |
| lines = [f"{indent}{key}" for key in sample] |
| if len(keys) > limit: |
| lines.append(f"{indent}... {len(keys) - limit} more") |
| return "\n".join(lines) |
|
|
|
|
| def _format_shape_mismatch_samples(shape_mismatch_keys, limit: int = 25, indent: str = " ") -> str: |
| sample = list(shape_mismatch_keys[:limit]) |
| lines = [ |
| f"{indent}{checkpoint_key} -> {model_key}: checkpoint{checkpoint_shape} != model{model_shape}" |
| for checkpoint_key, model_key, checkpoint_shape, model_shape in sample |
| ] |
| if len(shape_mismatch_keys) > limit: |
| lines.append(f"{indent}... {len(shape_mismatch_keys) - limit} more") |
| return "\n".join(lines) |
|
|
|
|
| def _log_checkpoint_mismatch_report( |
| stats, |
| checkpoint_path, |
| label: str | None = None, |
| dememwm_key_check: bool = False, |
| ) -> None: |
| mismatch_count = stats["missing_model"] + stats["unexpected"] + stats["shape_mismatch"] |
| if mismatch_count == 0: |
| return |
|
|
| title = label or str(checkpoint_path) |
| lines = [ |
| f"Checkpoint mismatch report for {title}:", |
| f" checkpoint={checkpoint_path}", |
| " selected_prefix={!r}".format(stats["prefix"]), |
| ( |
| " counts: " |
| "matched={matched} " |
| "model_not_loaded={model_not_loaded} " |
| "checkpoint_not_used={checkpoint_not_used} " |
| "shape_mismatch={shape_mismatch} " |
| "skipped_by_prefix={skipped_by_prefix}" |
| ).format( |
| matched=stats["matched"], |
| model_not_loaded=stats["missing_model"], |
| checkpoint_not_used=stats["unexpected"], |
| shape_mismatch=stats["shape_mismatch"], |
| skipped_by_prefix=stats["skipped_prefix"], |
| ), |
| ] |
| if stats["missing_model_keys"]: |
| lines.append(" Model keys not loaded from checkpoint:") |
| lines.append(_format_key_samples(stats["missing_model_keys"])) |
| if stats["unexpected_checkpoint_keys"]: |
| lines.append(" Checkpoint keys not used by current model:") |
| lines.append(_format_key_samples(stats["unexpected_checkpoint_keys"])) |
| if stats["shape_mismatch_keys"]: |
| lines.append(" Shape mismatches:") |
| lines.append(_format_shape_mismatch_samples(stats["shape_mismatch_keys"])) |
|
|
| if dememwm_key_check: |
| markers = ("dememwm_", ".dememwm_", ".memory_token_cross_attn.") |
| missing_dememwm = [key for key in stats["missing_model_keys"] if _key_matches_any(key, markers)] |
| unexpected_dememwm = [ |
| key for key in stats["unexpected_checkpoint_keys"] if _key_matches_any(key, markers) |
| ] |
| shape_dememwm = [ |
| item |
| for item in stats["shape_mismatch_keys"] |
| if _key_matches_any(item[0], markers) or _key_matches_any(item[1], markers) |
| ] |
| if missing_dememwm or unexpected_dememwm or shape_dememwm: |
| lines.append(" DeMemWM mismatch subset:") |
| if missing_dememwm: |
| lines.append(" DeMemWM model keys not loaded:") |
| lines.append(_format_key_samples(missing_dememwm, indent=" ")) |
| if unexpected_dememwm: |
| lines.append(" DeMemWM checkpoint keys not used:") |
| lines.append(_format_key_samples(unexpected_dememwm, indent=" ")) |
| if shape_dememwm: |
| lines.append(" DeMemWM shape mismatches:") |
| lines.append(_format_shape_mismatch_samples(shape_dememwm, indent=" ")) |
|
|
| rank_zero_info("\n".join(lines)) |
|
|
|
|
| def load_custom_checkpoint( |
| algo, |
| checkpoint_path, |
| require_match: bool = False, |
| label: str | None = None, |
| dememwm_key_check: bool = False, |
| report_key_mismatch: bool = False, |
| ): |
| if not checkpoint_path: |
| if require_match: |
| target = label or "model" |
| raise FileNotFoundError(f"Expected checkpoint for {target}, but no path was provided.") |
| rank_zero_info("No checkpoint path provided, skipping checkpoint loading.") |
| return None |
|
|
| if not isinstance(checkpoint_path, Path): |
| checkpoint_path = Path(checkpoint_path) |
|
|
| if is_huggingface_model(str(checkpoint_path)): |
| hf_ckpt = str(checkpoint_path).split("/") |
| repo_id = "/".join(hf_ckpt[:2]) |
| file_name = "/".join(hf_ckpt[2:]) |
| model_path = hf_hub_download(repo_id=repo_id, filename=file_name) |
| ckpt = torch.load(model_path, map_location=torch.device("cpu")) |
| state_dict = _extract_state_dict(ckpt) |
|
|
| elif checkpoint_path.suffix == ".pt": |
| ckpt = torch.load(checkpoint_path, weights_only=True) |
| state_dict = _extract_state_dict(ckpt) |
|
|
| elif checkpoint_path.suffix == ".ckpt": |
| ckpt = torch.load(checkpoint_path, map_location=torch.device("cpu")) |
| state_dict = _extract_state_dict(ckpt) |
|
|
| elif checkpoint_path.suffix == ".safetensors": |
| state_dict = load_file(checkpoint_path) |
|
|
| elif os.path.isdir(checkpoint_path): |
| ckpt_files = [f for f in os.listdir(checkpoint_path) if f.endswith(".ckpt")] |
| if not ckpt_files: |
| raise FileNotFoundError("No .ckpt files found in the specified directory!") |
| selected_ckpt = max(ckpt_files) |
| selected_ckpt_path = os.path.join(checkpoint_path, selected_ckpt) |
| print(f"Checkpoint file selected for loading: {selected_ckpt_path}") |
|
|
| ckpt = torch.load(selected_ckpt_path, map_location=torch.device("cpu")) |
| state_dict = _extract_state_dict(ckpt) |
|
|
| else: |
| raise ValueError(f"Unsupported checkpoint: {checkpoint_path}") |
| if dememwm_key_check and hasattr(algo, "strict_dememwm_checkpoint_key_check"): |
| algo.strict_dememwm_checkpoint_key_check(state_dict) |
| compatible, stats = _compatible_state_dict(algo, state_dict) |
| if require_match and stats["matched"] == 0: |
| raise RuntimeError(f"Expected checkpoint for {label or checkpoint_path} matched zero model weights.") |
| if compatible: |
| algo.load_state_dict(compatible, strict=False) |
| elif checkpoint_path.suffix == ".safetensors": |
| load_model(algo, checkpoint_path, strict=False) |
| rank_zero_info( |
| "Model weights loaded from {}: matched={} missing={} shape_mismatch={} prefix={!r}".format( |
| checkpoint_path, |
| stats["matched"], |
| stats["missing"], |
| stats["shape_mismatch"], |
| stats["prefix"], |
| ) |
| ) |
| if report_key_mismatch: |
| _log_checkpoint_mismatch_report(stats, checkpoint_path, label=label, dememwm_key_check=dememwm_key_check) |
| return stats |
|
|
|
|
class BaseExperiment(ABC):
    """
    Abstract class for an experiment. This generalizes the pytorch lightning Trainer & lightning Module to more
    flexible experiments that don't fit in the typical ml loop, e.g. multi-stage reinforcement learning benchmarks.
    """

    # Mapping from algorithm name (``root_cfg.algorithm._name``) to a callable that
    # builds the algorithm from ``root_cfg.algorithm``; subclasses must override it.
    compatible_algorithms: Dict = NotImplementedError

    def __init__(
        self,
        root_cfg: DictConfig,
        logger: Optional[WandbLogger] = None,
        ckpt_path: Optional[Union[str, pathlib.Path]] = None,
    ) -> None:
        """
        Constructor

        Args:
            root_cfg: configuration that contains everything about the experiment; its ``experiment``
                entry is stored as ``self.cfg`` and its ``debug`` entry as ``self.debug``
            logger: a pytorch-lightning WandbLogger instance
            ckpt_path: an optional path to saved checkpoint
        """
        super().__init__()
        self.root_cfg = root_cfg
        self.cfg = root_cfg.experiment
        self.debug = root_cfg.debug
        self.logger = logger
        self.ckpt_path = ckpt_path
        self.algo = None

        # Optional switches and paths read from the root config; each falls back
        # to the default given here when the config does not provide it.
        self.customized_load = getattr(root_cfg, "customized_load", False)
        self.seperate_load = getattr(root_cfg, "seperate_load", False)
        self.zero_init_gate = getattr(root_cfg, "zero_init_gate", False)
        self.only_tune_memory = getattr(root_cfg, "only_tune_memory", False)
        self.diffusion_model_path = getattr(root_cfg, "diffusion_model_path", None)
        self.vae_path = getattr(root_cfg, "vae_path", None)
        self.pose_predictor_path = getattr(root_cfg, "pose_predictor_path", None)
        self.auto_resuming = getattr(root_cfg, "_auto_resuming", False)

    def _build_algo(self):
        """
        Build the lightning module
        :return: a pytorch-lightning module to be launched
        """
        algo_name = self.root_cfg.algorithm._name
        if algo_name not in self.compatible_algorithms:
            raise ValueError(
                f"Algorithm {algo_name} not found in compatible_algorithms for this Experiment class. "
                "Make sure you define compatible_algorithms correctly and make sure that each key has "
                "same name as yaml file under '[project_root]/configurations/algorithm' without .yaml suffix"
            )
        return self.compatible_algorithms[algo_name](self.root_cfg.algorithm)

    def _resolve_task(self, task: str):
        """
        Look up the callable attribute named ``task`` and announce it on rank zero.

        Shared by ``exec_task`` and ``exec_interactive`` so both validate and
        announce a task in exactly the same way.

        Raises:
            ValueError: ``task`` is not an attribute of this object or is not callable
        """
        handler = getattr(self, task, None)
        if not callable(handler):
            raise ValueError(
                f"Specified task '{task}' not defined for class {self.__class__.__name__} or is not callable."
            )
        if is_rank_zero:
            print(cyan("Executing task:"), f"{task} out of {self.cfg.tasks}")
        return handler

    def exec_task(self, task: str) -> None:
        """
        Executing a certain task specified by string. Each task should be a stage of experiment.
        In most computer vision / nlp applications, tasks should be just train and test.
        In reinforcement learning, you might have more stages such as collecting dataset etc

        Args:
            task: a string specifying a task implemented for this experiment

        Raises:
            ValueError: the task is not defined on this class or is not callable
        """
        self._resolve_task(task)()

    def exec_interactive(self, task: str) -> None:
        """
        Same as ``exec_task``, except that whatever the task returns is handed back to the caller.

        Args:
            task: a string specifying a task implemented for this experiment

        Raises:
            ValueError: the task is not defined on this class or is not callable
        """
        return self._resolve_task(task)()
|
|
class BaseLightningExperiment(BaseExperiment):
    """
    Abstract class for pytorch lightning experiments. Useful for computer vision & nlp where main components are
    simply models, datasets and train loop.
    """

    # Mapping from algorithm name to a callable building it; subclasses must override.
    compatible_algorithms: Dict = NotImplementedError

    # Mapping from dataset name (``root_cfg.dataset._name``) to a callable invoked as
    # ``factory(root_cfg.dataset, split=...)``; subclasses must override.
    compatible_datasets: Dict = NotImplementedError

    def _build_trainer_callbacks(self):
        """
        Build the default trainer callbacks.

        Returns:
            A list holding a step-wise ``LearningRateMonitor`` when a logger is configured,
            otherwise an empty list.
        """
        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        # Fix: the list used to be built and then dropped (the method returned None).
        return callbacks

    def _make_dataloader(self, split: str):
        """
        Build the ``DataLoader`` for ``split`` ("training", "validation" or "test").

        Batch size, worker count and shuffle flag are read from ``self.cfg.<split>``.

        Returns:
            The data loader, or None when the dataset for the split is falsy.
        """
        dataset = self._build_dataset(split)
        if not dataset:
            return None

        split_cfg = getattr(self.cfg, split)
        # Iterable datasets are never shuffled by the loader, whatever the config says.
        if isinstance(dataset, torch.utils.data.IterableDataset):
            shuffle = False
        else:
            shuffle = split_cfg.data.shuffle
        # ``os.cpu_count()`` may return None; treat that as a single CPU.
        num_workers = min(os.cpu_count() or 1, split_cfg.data.num_workers)
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=split_cfg.batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
            # DataLoader rejects persistent_workers=True when num_workers == 0.
            persistent_workers=num_workers > 0,
            pin_memory=torch.cuda.is_available(),
        )

    def _build_training_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the training data loader, or None if there is no training dataset."""
        return self._make_dataloader("training")

    def _build_validation_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the validation data loader, or None if there is no validation dataset."""
        return self._make_dataloader("validation")

    def _build_test_loader(self) -> Optional[Union[TRAIN_DATALOADERS, pl.LightningDataModule]]:
        """Return the test data loader, or None if there is no test dataset."""
        return self._make_dataloader("test")

    def _load_custom_weights(self, report_key_mismatch: bool) -> None:
        """
        Load weights into ``self.algo`` through ``load_custom_checkpoint``.

        With ``seperate_load`` the diffusion model and the VAE are loaded from their own
        checkpoints (both must match at least one weight); otherwise ``self.ckpt_path`` is
        loaded into the whole algorithm.

        Args:
            report_key_mismatch: forwarded to ``load_custom_checkpoint``
        """
        if not self.seperate_load:
            load_custom_checkpoint(
                algo=self.algo,
                checkpoint_path=self.ckpt_path,
                label="algo",
                dememwm_key_check=True,
                report_key_mismatch=report_key_mismatch,
            )
            return

        # An unset path is treated as "" so that the loader raises its own clear
        # FileNotFoundError instead of a TypeError from the ``in`` test below.
        if "oasis500m" in str(self.diffusion_model_path or ""):
            # Paths containing 'oasis500m' are loaded into the inner ``.model``.
            # NOTE(review): presumably those checkpoints hold only the inner network -- confirm.
            diffusion_target, diffusion_label = self.algo.diffusion_model.model, "diffusion_model.model"
        else:
            diffusion_target, diffusion_label = self.algo.diffusion_model, "diffusion_model"
        load_custom_checkpoint(
            algo=diffusion_target,
            checkpoint_path=self.diffusion_model_path,
            require_match=True,
            label=diffusion_label,
            report_key_mismatch=report_key_mismatch,
        )
        load_custom_checkpoint(
            algo=self.algo.vae,
            checkpoint_path=self.vae_path,
            require_match=True,
            label="vae",
            report_key_mismatch=report_key_mismatch,
        )

    def _zero_init_gates(self) -> None:
        """
        Zero two fixed slices of every diffusion-model parameter whose name contains
        'r_adaLN_modulation'.
        """
        for name, para in self.algo.diffusion_model.named_parameters():
            if 'r_adaLN_modulation' in name:
                # Gradients are switched off around the in-place writes and then
                # switched on again unconditionally.
                para.requires_grad_(False)
                # NOTE(review): rows [2048, 3072) and [5120, 6144) are assumed to be the gate
                # chunks of a 6-way modulation with hidden size 1024 -- TODO confirm.
                para[2*1024:3*1024] = 0
                para[5*1024:6*1024] = 0
                para.requires_grad_(True)

    def _freeze_all_but_memory(self) -> None:
        """
        Freeze every diffusion-model parameter except those whose name contains
        'r_', 'pose_embedder', 'pose_cond_mlp' or 'lora_'.
        """
        trainable_markers = ('r_', 'pose_embedder', 'pose_cond_mlp', 'lora_')
        for name, para in self.algo.diffusion_model.named_parameters():
            para.requires_grad_(any(marker in name for marker in trainable_markers))

    def training(self) -> None:
        """
        All training happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.training.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []
        if self.logger:
            callbacks.append(LearningRateMonitor("step", True))
        if "checkpointing" in self.cfg.training:
            callbacks.append(
                ModelCheckpoint(
                    pathlib.Path(hydra.core.hydra_config.HydraConfig.get()["runtime"]["output_dir"]) / "checkpoints",
                    filename="epoch{epoch}_step{step}",
                    auto_insert_metric_name=False,
                    **self.cfg.training.checkpointing,
                )
            )

        trainer = pl.Trainer(
            accelerator="auto",
            devices="auto",
            strategy=DDPStrategy(find_unused_parameters=True) if torch.cuda.device_count() > 1 else "auto",
            logger=self.logger or False,
            callbacks=callbacks,
            gradient_clip_val=self.cfg.training.optim.gradient_clip_val or 0.0,
            val_check_interval=self.cfg.validation.val_every_n_step if self.cfg.validation.val_every_n_step else None,
            limit_val_batches=self.cfg.validation.limit_batch,
            check_val_every_n_epoch=self.cfg.validation.val_every_n_epoch if not self.cfg.validation.val_every_n_step else None,
            accumulate_grad_batches=self.cfg.training.optim.accumulate_grad_batches or 1,
            precision=self.cfg.training.precision or 32,
            detect_anomaly=False,
            num_sanity_val_steps=int(self.cfg.debug) if self.cfg.debug else 0,
            max_epochs=self.cfg.training.max_epochs,
            max_steps=self.cfg.training.max_steps,
            max_time=self.cfg.training.max_time
        )

        if self.auto_resuming:
            # Resume through Lightning from ``self.ckpt_path``; the flag set here
            # is read elsewhere (not in this class).
            self.algo._strict_resume_state = True
            fit_ckpt_path = self.ckpt_path
        elif self.customized_load:
            # Weights are loaded manually, so Lightning must not restore a checkpoint itself.
            self._load_custom_weights(report_key_mismatch=False)
            if self.zero_init_gate:
                self._zero_init_gates()
            if self.only_tune_memory:
                self._freeze_all_but_memory()
            fit_ckpt_path = None
        else:
            if self.only_tune_memory:
                self._freeze_all_but_memory()
            fit_ckpt_path = self.ckpt_path

        trainer.fit(
            self.algo,
            train_dataloaders=self._build_training_loader(),
            val_dataloaders=self._build_validation_loader(),
            ckpt_path=fit_ckpt_path,
        )

    def validation(self) -> None:
        """
        All validation happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.validation.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_val_batches=self.cfg.validation.limit_batch,
            precision=self.cfg.validation.precision,
            detect_anomaly=False,
            inference_mode=self.cfg.validation.inference_mode,
        )

        if self.customized_load:
            # Weights are loaded manually, so Lightning must not restore a checkpoint itself.
            self._load_custom_weights(report_key_mismatch=True)
            if self.zero_init_gate:
                self._zero_init_gates()
            validate_ckpt_path = None
        else:
            validate_ckpt_path = self.ckpt_path

        trainer.validate(
            self.algo,
            dataloaders=self._build_validation_loader(),
            ckpt_path=validate_ckpt_path,
        )

    def test(self) -> None:
        """
        All testing happens here
        """
        if not self.algo:
            self.algo = self._build_algo()
        if self.cfg.test.compile:
            self.algo = torch.compile(self.algo)

        callbacks = []

        trainer = pl.Trainer(
            accelerator="auto",
            logger=self.logger,
            devices="auto",
            num_nodes=self.cfg.num_nodes,
            strategy=DDPStrategy(find_unused_parameters=False) if torch.cuda.device_count() > 1 else "auto",
            callbacks=callbacks,
            limit_test_batches=self.cfg.test.limit_batch,
            precision=self.cfg.test.precision,
            detect_anomaly=False,
        )

        if self.customized_load:
            # Weights are loaded manually, so Lightning must not restore a checkpoint itself.
            self._load_custom_weights(report_key_mismatch=True)
            if self.zero_init_gate:
                self._zero_init_gates()
            test_ckpt_path = None
        else:
            test_ckpt_path = self.ckpt_path

        trainer.test(
            self.algo,
            dataloaders=self._build_test_loader(),
            ckpt_path=test_ckpt_path,
        )

    def _build_dataset(self, split: str) -> Optional[torch.utils.data.Dataset]:
        """
        Instantiate the configured dataset for ``split``.

        Args:
            split: one of "training", "test" or "validation"

        Raises:
            NotImplementedError: ``split`` is not one of the supported names
        """
        if split in ["training", "test", "validation"]:
            return self.compatible_datasets[self.root_cfg.dataset._name](self.root_cfg.dataset, split=split)
        else:
            raise NotImplementedError(f"split '{split}' is not implemented")
|
|