| from transformers.configuration_utils import PretrainedConfig |
| from transformers.utils import logging |
|
|
# Module-level logger obtained from transformers' logging utilities, keyed by this
# module's name. NOTE(review): it is not referenced in the code visible here.
logger = logging.get_logger(__name__)
|
|
class EmptyClass(PretrainedConfig):
    """Bare attribute namespace used for the nested sections of ``SDConfig``.

    It subclasses ``PretrainedConfig`` so that a parent config's ``to_dict()``
    serialises these sections as nested dicts (transformers converts nested
    ``PretrainedConfig`` values that way). ``PretrainedConfig.__init__`` is not
    called, so an instance carries only the attributes explicitly assigned to
    it and none of the base-class default attributes.
    """

    def __init__(self, **kwargs):
        """Create the namespace, storing each keyword argument as an attribute.

        With no arguments the instance starts empty. Accepting keywords allows
        a section to be rebuilt from a plain dict, e.g.
        ``EmptyClass(**{"freeze_vae": True})``.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)
class SDConfig(PretrainedConfig):
    """Configuration with nested ``model``, ``tta`` and ``input`` sections.

    Most constructor arguments are stored on one of three namespace objects
    rather than on the config itself:

    * ``self.model``: ``override_total_steps``, ``freeze_vae``, ``use_flash``
    * ``self.tta``: ``adapt_topk``, ``loss``, ``use_same_noise_among_timesteps``,
      ``random_timestep_per_iteration``, ``rand_timestep_equal_int`` and an
      initially empty ``gradient_descent`` sub-section
    * ``self.input``: ``mean``, ``std``

    ``output_dir``, ``do_center_crop_size`` and ``architectures`` are stored at
    the top level.

    Args:
        override_total_steps: Stored as ``model.override_total_steps``.
        freeze_vae: Stored as ``model.freeze_vae``.
        use_flash: Stored as ``model.use_flash``.
        adapt_topk: Stored as ``tta.adapt_topk``.
        loss: Stored as ``tta.loss``.
        mean: Stored as ``input.mean``. ``None`` means
            ``[0.485, 0.456, 0.406]`` (a fresh list for every instance).
        std: Stored as ``input.std``. ``None`` means
            ``[0.229, 0.224, 0.225]`` (a fresh list for every instance).
        use_same_noise_among_timesteps: Stored on ``tta``.
        random_timestep_per_iteration: Stored on ``tta``.
        rand_timestep_equal_int: Stored on ``tta``.
        output_dir: Stored as ``output_dir``.
        do_center_crop_size: Stored as ``do_center_crop_size``.
        architectures: Stored as ``architectures``.
        input, model, tta: Optional overrides for the matching section. Each is
            given as a dict (the form in which a saved config's sections come
            back through ``from_dict``/``from_pretrained``) or as an object
            whose attributes are copied. Their entries take precedence over
            the flat arguments above.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        TypeError: If ``input``/``model``/``tta`` is neither ``None``, a dict,
            nor an object with a ``__dict__``.
    """

    # Bookkeeping keys that transformers' ``to_dict()`` may emit inside nested
    # config sections. They are not settings, so they are not copied back.
    _SECTION_METADATA_KEYS = ("model_type", "transformers_version")

    def __init__(self,
                 override_total_steps=-1,
                 freeze_vae=True,
                 use_flash=False,
                 adapt_topk=-1,
                 loss='mse',
                 mean=None,
                 std=None,
                 use_same_noise_among_timesteps=False,
                 random_timestep_per_iteration=True,
                 rand_timestep_equal_int=False,
                 output_dir='./outputs/First_Start',
                 do_center_crop_size=384,
                 architectures=None,
                 input=None,
                 model=None,
                 tta=None,
                 **kwargs):
        # Custom configs must hand unknown kwargs to the base class. The base
        # class processes the keys it owns and stores the rest as attributes.
        super().__init__(**kwargs)

        # Build the default lists per call. A list literal used as a default
        # argument would be shared by every SDConfig instance.
        if mean is None:
            mean = [0.485, 0.456, 0.406]
        if std is None:
            std = [0.229, 0.224, 0.225]

        self.model = EmptyClass()
        self.model.override_total_steps = override_total_steps
        self.model.freeze_vae = freeze_vae
        self.model.use_flash = use_flash

        self.tta = EmptyClass()
        self.tta.gradient_descent = EmptyClass()
        self.tta.adapt_topk = adapt_topk
        self.tta.loss = loss
        self.tta.use_same_noise_among_timesteps = use_same_noise_among_timesteps
        self.tta.random_timestep_per_iteration = random_timestep_per_iteration
        self.tta.rand_timestep_equal_int = rand_timestep_equal_int

        self.input = EmptyClass()
        self.input.mean = mean
        self.input.std = std

        # When a saved config is reloaded, its sections arrive here as nested
        # dicts. Without merging them, the reloaded config would silently fall
        # back to the defaults above.
        self._merge_section(self.model, model)
        self._merge_section(self.tta, tta)
        self._merge_section(self.input, input)

        self.output_dir = output_dir
        self.do_center_crop_size = do_center_crop_size
        self.architectures = architectures

    @classmethod
    def _merge_section(cls, section, values):
        """Copy the entries of ``values`` onto the namespace object ``section``.

        ``values`` may be ``None`` (nothing to do), a ``dict``, or any object
        with a ``__dict__`` (its attributes are copied). A dict value whose
        target attribute is already a config namespace (e.g.
        ``tta.gradient_descent``) is merged into it recursively rather than
        replacing it. Keys in ``_SECTION_METADATA_KEYS`` are skipped.

        Raises:
            TypeError: If ``values`` is not ``None``, a dict, or an object
                with a ``__dict__`` (raised by ``vars``).
        """
        if values is None:
            return
        items = values.items() if isinstance(values, dict) else vars(values).items()
        for key, value in items:
            if key in cls._SECTION_METADATA_KEYS:
                continue
            current = getattr(section, key, None)
            if isinstance(current, PretrainedConfig) and isinstance(value, dict):
                cls._merge_section(current, value)
            else:
                setattr(section, key, value)
if __name__ == "__main__":
    # Smoke check: constructing the default configuration must not raise.
    # The resulting object is not used further.
    _ = SDConfig()
|
|