| | import torch |
| | from accelerate import Accelerator |
| | from accelerate.logging import MultiProcessAdapter |
| | from dataclasses import dataclass, field |
| | from typing import Optional, Union |
| | from datasets import load_dataset |
| | import json |
| | import abc |
| | from diffusers.utils import make_image_grid |
| | import numpy as np |
| | import wandb |
| |
|
| | from custum_3d_diffusion.trainings.utils import load_config |
| | from custum_3d_diffusion.custum_modules.unifield_processor import ConfigurableUNet2DConditionModel, AttnConfig |
| |
|
class BasicTrainer(torch.nn.Module, abc.ABC):
    """Abstract base class for multiview-diffusion trainers.

    Owns the plumbing shared by all trainers -- config parsing, validation
    logging, checkpoint saving -- while concrete subclasses implement
    dataset/dataloader construction, the training forward pass and the
    inference pipeline.
    """

    # Collaborators injected via __init__. String (forward-ref) annotations
    # keep the class importable without eagerly evaluating these heavy types.
    accelerator: "Accelerator"
    logger: "MultiProcessAdapter"
    unet: "ConfigurableUNet2DConditionModel"
    train_dataloader: "torch.utils.data.DataLoader"
    test_dataset: "torch.utils.data.Dataset"
    attn_config: "AttnConfig"

    @dataclass
    class TrainerConfig:
        """Typed schema (with defaults) for the trainer section of the config."""

        trainer_name: str = "basic"
        pretrained_model_name_or_path: str = ""

        # Raw dict; parsed into an AttnConfig in __init__.
        attn_config: dict = field(default_factory=dict)
        dataset_name: str = ""
        dataset_config_name: Optional[str] = None
        # JSON-encoded: either a single int ("1024") or an "[H, W]" pair.
        resolution: str = "1024"
        dataloader_num_workers: int = 4
        pair_sampler_group_size: int = 1
        num_views: int = 4

        max_train_steps: int = -1  # -1 means "no step limit"
        training_step_interval: int = 1
        max_train_samples: Optional[int] = None
        seed: Optional[int] = None
        train_batch_size: int = 1

        validation_interval: int = 5000
        debug: bool = False

    cfg: TrainerConfig

    def __init__(
        self,
        accelerator: "Accelerator",
        logger: "MultiProcessAdapter",
        unet: "ConfigurableUNet2DConditionModel",
        config: Union[dict, str],
        weight_dtype: torch.dtype,
        index: int,
    ):
        """Store collaborators, parse `config`, then run subclass hooks.

        Args:
            accelerator: accelerate handle (process group, trackers, ...).
            logger: multi-process-aware logger.
            unet: reconfigurable UNet shared across trainers.
            config: dict or path accepted by `load_config`.
            weight_dtype: dtype used for model weights during training.
            index: position of this trainer among all active trainers.
        """
        super().__init__()
        self.index = index
        self.accelerator = accelerator
        self.logger = logger
        self.unet = unet
        self.weight_dtype = weight_dtype
        self.ext_logs = {}  # extra metrics flushed to wandb by log_validation
        self.cfg = load_config(self.TrainerConfig, config)
        self.attn_config = load_config(AttnConfig, self.cfg.attn_config)
        self.test_dataset = None  # populated later by init_train_dataloader
        self.validate_trainer_config()
        self.configure()

    def get_HW(self):
        """Return (H, W) decoded from cfg.resolution (JSON int or [H, W] list).

        Raises:
            ValueError: if the decoded value is neither an int nor a list.
                (Previously such input fell through both branches and
                surfaced as an UnboundLocalError.)
        """
        resolution = json.loads(self.cfg.resolution)
        if isinstance(resolution, int):
            H = W = resolution
        elif isinstance(resolution, list):
            H, W = resolution
        else:
            raise ValueError(
                f"Unsupported resolution spec: {self.cfg.resolution!r}"
            )
        return H, W

    def unet_update(self):
        """Push the current attention config into the shared UNet."""
        self.unet.update_config(self.attn_config)

    def validate_trainer_config(self):
        """Hook: subclasses may raise here on inconsistent cfg values."""
        pass

    def is_train_finished(self, current_step):
        """True when a step limit is configured and `current_step` reached it."""
        max_steps = self.cfg.max_train_steps
        if not isinstance(max_steps, int):
            # Was a bare assert, which silently disappears under `python -O`.
            raise TypeError(
                f"max_train_steps must be int, got {type(max_steps).__name__}"
            )
        return max_steps != -1 and current_step >= max_steps

    def next_train_step(self, current_step):
        """Next step this trainer should run at, or None once training is done."""
        if self.is_train_finished(current_step):
            return None
        return current_step + self.cfg.training_step_interval

    @classmethod
    def make_image_into_grid(cls, all_imgs, rows=2, columns=2):
        """Tile `all_imgs` into rows*columns grids, then concatenate the grids
        side by side into a single 1-row image."""
        group = rows * columns
        catted = [
            make_image_grid(all_imgs[i:i + group], rows=rows, cols=columns)
            for i in range(0, len(all_imgs), group)
        ]
        return make_image_grid(catted, rows=1, cols=len(catted))

    def configure(self) -> None:
        """Hook: subclass setup, invoked at the end of __init__."""
        pass

    @abc.abstractmethod
    def init_shared_modules(self, shared_modules: dict) -> dict:
        """Populate/extend `shared_modules` (vae, encoders, ...) and return it."""
        pass

    def load_dataset(self):
        """Load the HF dataset named by the config.

        NOTE(review): trust_remote_code executes dataset-repo code;
        only use with datasets you trust.
        """
        dataset = load_dataset(
            self.cfg.dataset_name,
            self.cfg.dataset_config_name,
            trust_remote_code=True,
        )
        return dataset

    @abc.abstractmethod
    def init_train_dataloader(self, shared_modules: dict) -> "torch.utils.data.DataLoader":
        """Both init train_dataloader and test_dataset, but returns train_dataloader only."""
        pass

    @abc.abstractmethod
    def forward_step(self, *args, **kwargs) -> torch.Tensor:
        """Consume one batch and return the training loss.

        The base implementation only refreshes the UNet's attention config;
        subclasses should invoke it (or unet_update) before their own work.
        """
        self.unet_update()

    @abc.abstractmethod
    def construct_pipeline(self, shared_modules, unet):
        """Build the diffusers inference pipeline from shared modules + unet."""
        pass

    @abc.abstractmethod
    def pipeline_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """For inference time forward."""
        pass

    @abc.abstractmethod
    def batched_validation_forward(self, pipeline, **pipeline_call_kwargs) -> tuple:
        """Run validation inference; returns (titles, images)."""
        pass

    def do_validation(self, shared_modules, unet, global_step):
        """Run validation inference and log the images to every active tracker."""
        self.unet_update()
        self.logger.info("Running validation... ")
        pipeline = self.construct_pipeline(shared_modules, unet)
        pipeline.set_progress_bar_config(disable=True)
        titles, images = self.batched_validation_forward(pipeline, guidance_scale=[1., 3.])
        for tracker in self.accelerator.trackers:
            if tracker.name == "tensorboard":
                np_images = np.stack([np.asarray(img) for img in images])
                tracker.writer.add_images("validation", np_images, global_step, dataformats="NHWC")
            elif tracker.name == "wandb":
                # Shrink images in place unless the title opts out.
                # (Was a list comprehension used purely for side effects.)
                for image, title in zip(images, titles):
                    if 'noresize' not in title:
                        image.thumbnail((512, 512))
                tracker.log({"validation": [
                    wandb.Image(image, caption=f"{i}: {titles[i]}", file_type="jpg")
                    for i, image in enumerate(images)]})
            else:
                # logger.warn is deprecated in favor of logger.warning
                self.logger.warning(f"image logging not implemented for {tracker.name}")
        del pipeline
        torch.cuda.empty_cache()
        return images

    @torch.no_grad()
    def log_validation(self, shared_modules, unet, global_step, force=False):
        """Flush ext_logs to wandb and, on validation steps (or `force`),
        run do_validation on the main process."""
        if self.accelerator.is_main_process:
            for tracker in self.accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(self.ext_logs)
                    self.ext_logs = {}
        if (global_step % self.cfg.validation_interval == 0 and not self.is_train_finished(global_step)) or force:
            self.unet_update()
            if self.accelerator.is_main_process:
                self.do_validation(shared_modules, self.accelerator.unwrap_model(unet), global_step)

    def save_model(self, unwrap_unet, shared_modules, save_dir):
        """Main process only: rebuild the pipeline and save_pretrained to save_dir."""
        if self.accelerator.is_main_process:
            pipeline = self.construct_pipeline(shared_modules, unwrap_unet)
            pipeline.save_pretrained(save_dir)
            self.logger.info(f"{self.cfg.trainer_name} Model saved at {save_dir}")

    def save_debug_info(self, save_name="debug", **kwargs):
        """When cfg.debug, pickle kwargs (tensors detached to CPU) to a pkl file.

        If `save_name.pkl` already exists, writes to the first free
        `save_name_v{i}.pkl` instead. (The original capped the probe at 100
        versions and silently overwrote the base file when all were taken.)
        """
        if not self.cfg.debug:
            return
        to_saves = {
            key: value.detach().cpu() if isinstance(value, torch.Tensor) else value
            for key, value in kwargs.items()
        }
        import pickle
        import os
        if os.path.exists(f"{save_name}.pkl"):
            i = 0
            while os.path.exists(f"{save_name}_v{i}.pkl"):
                i += 1
            save_name = f"{save_name}_v{i}"
        with open(f"{save_name}.pkl", "wb") as f:
            pickle.dump(to_saves, f)