Spaces:
Sleeping
Sleeping
| import dataclasses | |
| from typing import Literal | |
@dataclasses.dataclass
class KoreanConfig:
    """Default settings returned by ``get_config`` when ``lang == "ko"``.

    Declared as a dataclass so that every annotated default below becomes a
    keyword argument of the generated ``__init__``. ``get_config`` relies on
    this when it calls ``KoreanConfig(**config_dict)``; without the decorator
    any non-empty override dict raises ``TypeError``.

    NOTE(review): the per-field remarks are inferred from the field names
    only -- confirm them against the code that consumes this config.
    """

    lr: float = 3e-4  # presumably the optimizer learning rate
    steps: int = 1000  # presumably the number of training steps
    warmup_ratio: float = 0.15  # presumably the fraction of steps spent warming up
    latent_size: int = 16
    mask_prob: float = 0.2
    max_len: int = 10
@dataclasses.dataclass
class JapaneseConfig:
    """Default settings returned by ``get_config`` for any ``lang`` other than ``"ko"``.

    Declared as a dataclass so that every annotated default below becomes a
    keyword argument of the generated ``__init__``. ``get_config`` relies on
    this when it calls ``JapaneseConfig(**config_dict)``; without the
    decorator any non-empty override dict raises ``TypeError``.

    NOTE(review): the per-field remarks are inferred from the field names
    only -- confirm them against the code that consumes this config.
    """

    lr: float = 3e-4  # presumably the optimizer learning rate
    steps: int = 1000  # presumably the number of training steps
    warmup_ratio: float = 0.15  # original trailing note: "0.15"
    latent_size: int = 64  # original trailing note: "24" -- presumably an earlier value; confirm
    mask_prob: float = 0.2
    max_len: int = 20
| type Config = KoreanConfig | JapaneseConfig | |
| def get_config(lang: Literal["ko", "ja"], config_dict: dict = {}): | |
| if lang == "ko": | |
| return KoreanConfig(**config_dict) | |
| else: | |
| return JapaneseConfig(**config_dict) | |
| class Stage2Config: | |
| lr: float = 5e-4 | |
| steps: int = 300 | |
| weight_decay: float = 1e-2 | |
| warmup_ratio: float = 0.1 | |
| grad_clip_norm: float | None = 1.0 | |
| diffusion_timesteps: int = 1000 | |
| # cycle latent loss | |
| cycle_start_epoch: int = 5 | |
| cycle_ramp_epochs: int = 10 | |
| max_lambda_cycle_latent: float = 0.1 | |
| # token cycle loss | |
| token_cycle_start_epoch: int = 20 | |
| token_ko_scale: float = 0.04 | |
| token_ja_scale: float = 0.1 | |
| # direct domain MMD loss | |
| domain_kj_scale: float = 0.08 | |
| domain_jk_scale: float = 0.09 | |
| # prior | |
| prior_kj_scale: float = 1.0 | |
| prior_jk_scale: float = 1.0 | |
| # Cylce Start | |
| cycle_start_timesteps: list[tuple[int, int]] = dataclasses.field( | |
| default_factory=lambda: [ | |
| (0, 500), | |
| ] | |
| ) | |
| repeat_penalty_ja_scale: float = 0.05 | |
| repeat_penalty_ko_scale: float = 0.02 | |
| sep_center_scale: float = 0.1 | |
| sep_count_scale: float = 0.1 | |
| sep_peak_scale: float = 0.05 | |
| use_gradient_checkpointing: bool = False | |
| use_amp: bool = True | |