from typing import List, Optional, Tuple

from transformers import PretrainedConfig


class PLPQConfig(PretrainedConfig):
    """Configuration container for a model registered under ``model_type = "PLPQ"``.

    Every constructor argument is stored unchanged as an attribute of the same
    name; any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        image_size: Input image size. Defaults to ``[512, 512]``
            (presumably ``[height, width]`` -- confirm against the model code).
        patch_size: Patch size. Defaults to ``16``.
        dropout: Dropout value. Defaults to ``0.0``.
        levels: Per-entry level counts. Defaults to ``[8, 8, 8, 5, 5, 5]``
            (presumably quantizer levels -- confirm against the model code).
        num_quantizers: Number of quantizers. Defaults to ``4``.
        num_in_channels: Number of input channels. Defaults to ``3``.
        num_out_channels: Number of output channels. Defaults to ``3``.
        use_wavelets: Wavelet toggle read by the model. Defaults to ``True``.
        encoder_blocks: Encoder block specifications. Defaults to an empty list.
        decoder_blocks: Decoder block specifications. Defaults to an empty list.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type: str = "PLPQ"

    def __init__(
        self,
        image_size: Optional[List[int]] = None,
        patch_size: int = 16,
        dropout: float = 0.0,
        levels: Optional[List[int]] = None,
        num_quantizers: int = 4,
        num_in_channels: int = 3,
        num_out_channels: int = 3,
        use_wavelets: bool = True,
        encoder_blocks: Optional[List[List]] = None,
        decoder_blocks: Optional[List[List]] = None,
        **kwargs,
    ):
        # List defaults are created per call rather than in the signature:
        # a list literal in the signature is evaluated once and would be
        # shared (and mutable) across every instance that relies on it.
        self.image_size = [512, 512] if image_size is None else image_size
        self.patch_size = patch_size
        self.dropout = dropout
        self.levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        self.num_quantizers = num_quantizers
        self.num_in_channels = num_in_channels
        self.num_out_channels = num_out_channels
        self.use_wavelets = use_wavelets
        self.encoder_blocks = [] if encoder_blocks is None else encoder_blocks
        self.decoder_blocks = [] if decoder_blocks is None else decoder_blocks

        # Attributes are set before delegating to the base class, which
        # consumes the remaining keyword arguments.
        super().__init__(**kwargs)