Spaces:
Sleeping
Sleeping
File size: 1,796 Bytes
e0552b0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 | import dataclasses
from typing import Literal
@dataclasses.dataclass(frozen=True)
class KoreanConfig:
    """Immutable bundle of settings returned by ``get_config`` for ``"ko"``.

    Every field carries a default, so ``KoreanConfig()`` is valid and any
    subset of fields may be overridden by keyword. Instances are frozen:
    assigning to a field after construction raises an error.
    """

    # NOTE(review): the field meanings below are inferred from the names
    # only -- confirm against the code that consumes this config.
    lr: float = 0.0003  # presumably the optimizer learning rate
    steps: int = 1_000  # presumably the number of training steps
    warmup_ratio: float = 0.15  # presumably the warmup share of `steps`
    latent_size: int = 16  # presumably the latent dimensionality
    mask_prob: float = 0.2  # presumably a masking probability
    max_len: int = 10  # presumably the maximum sequence length
@dataclasses.dataclass(frozen=True)
class JapaneseConfig:
    """Immutable bundle of settings returned by ``get_config`` for non-``"ko"``.

    Every field carries a default, so ``JapaneseConfig()`` is valid and any
    subset of fields may be overridden by keyword. Instances are frozen:
    assigning to a field after construction raises an error.
    """

    # NOTE(review): the field meanings below are inferred from the names
    # only -- confirm against the code that consumes this config.
    lr: float = 0.0003  # presumably the optimizer learning rate
    steps: int = 1_000  # presumably the number of training steps
    # The original carried a trailing "0.15" remark here; it matches the
    # current value.
    warmup_ratio: float = 0.15
    # The original carried a trailing "24" remark here -- presumably an
    # earlier value that was tried; confirm before relying on it.
    latent_size: int = 64
    mask_prob: float = 0.2  # presumably a masking probability
    max_len: int = 20  # presumably the maximum sequence length
# Alias covering either per-language config class (PEP 695 `type` statement,
# which requires Python 3.12+).
type Config = KoreanConfig | JapaneseConfig
def get_config(
    lang: Literal["ko", "ja"],
    config_dict: dict | None = None,
) -> "KoreanConfig | JapaneseConfig":
    """Build the config object for ``lang``, applying optional overrides.

    Args:
        lang: ``"ko"`` selects ``KoreanConfig``; any other value selects
            ``JapaneseConfig``.
        config_dict: Mapping of field name to value, forwarded as keyword
            arguments to the selected class. ``None`` (the default) or an
            empty dict yields an instance with all defaults.

    Returns:
        A new ``KoreanConfig`` or ``JapaneseConfig`` instance.

    Raises:
        TypeError: If ``config_dict`` contains a key that is not a field of
            the selected dataclass (raised by its generated ``__init__``).
    """
    # A ``None`` sentinel replaces the former ``{}`` default so that no single
    # dict object is shared between calls.
    overrides = {} if config_dict is None else config_dict
    config_cls = KoreanConfig if lang == "ko" else JapaneseConfig
    return config_cls(**overrides)
@dataclasses.dataclass(frozen=True)
class Stage2Config:
lr: float = 5e-4
steps: int = 300
weight_decay: float = 1e-2
warmup_ratio: float = 0.1
grad_clip_norm: float | None = 1.0
diffusion_timesteps: int = 1000
# cycle latent loss
cycle_start_epoch: int = 5
cycle_ramp_epochs: int = 10
max_lambda_cycle_latent: float = 0.1
# token cycle loss
token_cycle_start_epoch: int = 20
token_ko_scale: float = 0.04
token_ja_scale: float = 0.1
# direct domain MMD loss
domain_kj_scale: float = 0.08
domain_jk_scale: float = 0.09
# prior
prior_kj_scale: float = 1.0
prior_jk_scale: float = 1.0
# Cylce Start
cycle_start_timesteps: list[tuple[int, int]] = dataclasses.field(
default_factory=lambda: [
(0, 500),
]
)
repeat_penalty_ja_scale: float = 0.05
repeat_penalty_ko_scale: float = 0.02
sep_center_scale: float = 0.1
sep_count_scale: float = 0.1
sep_peak_scale: float = 0.05
use_gradient_checkpointing: bool = False
use_amp: bool = True
|