from typing import List, Optional, Tuple

from transformers import PretrainedConfig
class PLPQConfig(PretrainedConfig):
    """Configuration object for the ``"PLPQ"`` model type.

    Each constructor argument is stored as an attribute of the same name.
    Any extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    List-valued parameters default to ``None`` and are replaced with a fresh
    list on every call. Instances therefore never share a default list object.
    The effective default values are the same as before.

    Args:
        image_size: Stored as ``self.image_size``; defaults to ``[512, 512]``.
            NOTE(review): presumably (height, width) — confirm with the model.
        patch_size: Stored as ``self.patch_size``; defaults to ``16``.
        dropout: Stored as ``self.dropout``; defaults to ``0.0``.
        levels: Stored as ``self.levels``; defaults to ``[8, 8, 8, 5, 5, 5]``.
            NOTE(review): looks like per-dimension quantization levels — confirm.
        num_quantizers: Stored as ``self.num_quantizers``; defaults to ``4``.
        num_in_channels: Stored as ``self.num_in_channels``; defaults to ``3``.
        num_out_channels: Stored as ``self.num_out_channels``; defaults to ``3``.
        use_wavelets: Stored as ``self.use_wavelets``; defaults to ``True``.
        encoder_blocks: Stored as ``self.encoder_blocks``; defaults to ``[]``.
            The schema of each entry is defined by the model code, not here.
        decoder_blocks: Stored as ``self.decoder_blocks``; defaults to ``[]``.
            The schema of each entry is defined by the model code, not here.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type: str = "PLPQ"

    def __init__(
        self,
        image_size: Optional[List[int]] = None,
        patch_size: int = 16,
        dropout: float = 0.0,
        levels: Optional[List[int]] = None,
        num_quantizers: int = 4,
        num_in_channels: int = 3,
        num_out_channels: int = 3,
        use_wavelets: bool = True,
        encoder_blocks: Optional[List[List]] = None,
        decoder_blocks: Optional[List[List]] = None,
        **kwargs,
    ):
        # Build list defaults per call. A literal list in the signature would be
        # one object shared by every instance, so mutating one config's list
        # would silently change the defaults of all configs created afterwards.
        self.image_size = [512, 512] if image_size is None else image_size
        self.patch_size = patch_size
        self.dropout = dropout
        self.levels = [8, 8, 8, 5, 5, 5] if levels is None else levels
        self.num_quantizers = num_quantizers
        self.num_in_channels = num_in_channels
        self.num_out_channels = num_out_channels
        self.use_wavelets = use_wavelets
        self.encoder_blocks = [] if encoder_blocks is None else encoder_blocks
        self.decoder_blocks = [] if decoder_blocks is None else decoder_blocks
        super().__init__(**kwargs)