github-actions[bot]
Sync from GitHub 33c12db74322f3d28409b5dc0a8c441914c9178b
e0552b0
import dataclasses
from typing import Literal
@dataclasses.dataclass(frozen=True)
class KoreanConfig:
    """Immutable hyperparameter set selected by ``get_config`` for ``lang == "ko"``.

    Instances are frozen: fields cannot be reassigned after construction.
    How each value is consumed is not visible in this module; the per-field
    notes below are inferred from the field names and should be confirmed
    against the training code.
    """

    # Optimisation schedule (presumably learning rate, total steps, and the
    # fraction of steps spent warming up -- TODO confirm).
    lr: float = 3e-4
    steps: int = 1_000
    warmup_ratio: float = 0.15

    # Model / data shape knobs (presumably latent width, masking probability,
    # and maximum sequence length -- TODO confirm units).
    latent_size: int = 16
    mask_prob: float = 0.2
    max_len: int = 10
@dataclasses.dataclass(frozen=True)
class JapaneseConfig:
    """Immutable hyperparameter set used by ``get_config`` for any non-"ko" language.

    Instances are frozen: fields cannot be reassigned after construction.
    How each value is consumed is not visible in this module; the per-field
    notes below are inferred from the field names and should be confirmed
    against the training code.
    """

    # Optimisation schedule (presumably learning rate, total steps, and the
    # fraction of steps spent warming up -- TODO confirm).
    lr: float = 3e-4
    steps: int = 1_000
    warmup_ratio: float = 0.15

    # NOTE(review): the original trailing comment on this field read "24",
    # possibly an earlier value that was tried -- confirm before relying on it.
    latent_size: int = 64

    # Presumably masking probability and maximum sequence length -- TODO confirm.
    mask_prob: float = 0.2
    max_len: int = 20
# Alias covering either per-language config class (Korean or Japanese).
type Config = KoreanConfig | JapaneseConfig
def get_config(
    lang: Literal["ko", "ja"],
    config_dict: dict | None = None,
) -> "KoreanConfig | JapaneseConfig":
    """Build the per-language config, optionally overriding default fields.

    Args:
        lang: Language code. ``"ko"`` selects ``KoreanConfig``; every other
            value (including ``"ja"``) selects ``JapaneseConfig``.
        config_dict: Field-name -> value overrides passed to the config
            constructor as keyword arguments. ``None`` (the default) or an
            empty dict yields the class defaults.

    Returns:
        A frozen ``KoreanConfig`` or ``JapaneseConfig`` instance.

    Raises:
        TypeError: If ``config_dict`` contains a key that is not a field of
            the selected config class.
    """
    # A None sentinel replaces the former `{}` default so that no mutable
    # object is shared between calls.
    overrides = {} if config_dict is None else config_dict
    config_cls = KoreanConfig if lang == "ko" else JapaneseConfig
    return config_cls(**overrides)
@dataclasses.dataclass(frozen=True)
class Stage2Config:
lr: float = 5e-4
steps: int = 300
weight_decay: float = 1e-2
warmup_ratio: float = 0.1
grad_clip_norm: float | None = 1.0
diffusion_timesteps: int = 1000
# cycle latent loss
cycle_start_epoch: int = 5
cycle_ramp_epochs: int = 10
max_lambda_cycle_latent: float = 0.1
# token cycle loss
token_cycle_start_epoch: int = 20
token_ko_scale: float = 0.04
token_ja_scale: float = 0.1
# direct domain MMD loss
domain_kj_scale: float = 0.08
domain_jk_scale: float = 0.09
# prior
prior_kj_scale: float = 1.0
prior_jk_scale: float = 1.0
# Cylce Start
cycle_start_timesteps: list[tuple[int, int]] = dataclasses.field(
default_factory=lambda: [
(0, 500),
]
)
repeat_penalty_ja_scale: float = 0.05
repeat_penalty_ko_scale: float = 0.02
sep_center_scale: float = 0.1
sep_count_scale: float = 0.1
sep_peak_scale: float = 0.05
use_gradient_checkpointing: bool = False
use_amp: bool = True