# ------------------------------------------------------------------------------------
# minDALL-E
# Copyright (c) 2021 Kakao Brain Corp. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
from typing import Optional, List
from dataclasses import dataclass, field

from omegaconf import OmegaConf


@dataclass
class DataConfig:
    # Dataset and text-tokenizer settings shared by both stages.
    dataset: Optional[str] = None
    tokenizer_type: str = 'CharBPE'
    context_length: int = 64           # maximum number of text tokens
    image_resolution: int = 256
    transforms: str = 'dalle-vqvae'
    bpe_pdrop: Optional[float] = None  # BPE-dropout probability (None disables it)


@dataclass
class Stage1Hparams:
    # VQGAN encoder/decoder hyperparameters (stage 1).
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    num_res_blocks: int = 2
    attn_resolutions: List[int] = field(default_factory=lambda: [16])
    pdrop: float = 0.0


@dataclass
class Stage2Hparams:
    # Autoregressive transformer hyperparameters (stage 2).
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    ctx_len_img: int = 256  # number of image tokens in the context
    ctx_len_txt: int = 64   # number of text tokens in the context
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    n_classes: Optional[int] = None


@dataclass
class Stage1Config:
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384  # codebook size
    # default_factory avoids the mutable-default error that dataclasses raise
    # for unhashable defaults on Python 3.11+.
    hparams: Stage1Hparams = field(default_factory=Stage1Hparams)


@dataclass
class Stage2Config:
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = field(default_factory=Stage2Hparams)


@dataclass
class WarmupConfig:
    # Learning-rate warmup schedule.
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    mode: str = 'fix'
    peak_lr: float = 1e-4
    start_from_zero: bool = True


@dataclass
class OptConfig:
    # Optimizer settings.
    opt_type: str = 'adamW'
    base_lr: float = 1e-4
    weight_decay: float = 1e-4
    betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
    grad_clip_norm: float = 1.0

    # LR-schedule settings.
    sched_type: str = 'cosine'
    max_steps: int = 0
    min_lr: float = 0.0


@dataclass
class ExpConfig:
    # Training-loop settings.
    local_batch_size: int = 4    # batch size per device
    total_batch_size: int = 512  # effective global batch size
    valid_batch_size: int = 32
    epochs: int = 10
    save_ckpt_freq: int = 2  # checkpoint every N epochs
    test_freq: int = 1       # validate every N epochs
    use_amp: bool = True     # automatic mixed precision


@dataclass
class DefaultConfig:
    # Model-only configuration (no optimizer or experiment settings).
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)


@dataclass
class FineTuningConfig:
    # Model configuration plus optimizer and experiment settings for training.
    dataset: DataConfig = field(default_factory=DataConfig)
    stage1: Stage1Config = field(default_factory=Stage1Config)
    stage2: Stage2Config = field(default_factory=Stage2Config)
    optimizer: OptConfig = field(default_factory=OptConfig)
    experiment: ExpConfig = field(default_factory=ExpConfig)


def get_base_config(use_default=True):
    # Build an OmegaConf structured config from the dataclass schema; the result
    # can be merged with YAML files or CLI overrides and stays type-checked
    # against the field types declared above.
    return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
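

# ------------------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of how the
# structured config above is typically consumed, assuming only the omegaconf
# package. The dotlist override values are illustrative, not project defaults.
# ------------------------------------------------------------------------------------
if __name__ == '__main__':
    config = get_base_config(use_default=False)  # FineTuningConfig schema

    # Merging into a structured config is validated: a misspelled key or a
    # value of the wrong type raises instead of silently adding a new field.
    overrides = OmegaConf.from_dotlist(['optimizer.base_lr=5.0e-5',
                                        'experiment.epochs=20'])
    config = OmegaConf.merge(config, overrides)

    print(config.optimizer.base_lr)   # 5e-05
    print(OmegaConf.to_yaml(config.experiment))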