File size: 1,796 Bytes
e0552b0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import dataclasses
from typing import Literal


@dataclasses.dataclass(frozen=True)
class KoreanConfig:
    """Immutable hyperparameter defaults returned by ``get_config("ko")``.

    Instances are frozen: attribute assignment after construction raises.
    The per-field notes are read off the field names only; the code that
    consumes these values is not part of this file.
    """

    # Optimisation schedule.
    lr: float = 3e-4  # presumably the optimiser learning rate
    steps: int = 1000  # presumably the number of training steps
    warmup_ratio: float = 0.15  # presumably the warm-up share of `steps`

    # Model / data shape.
    latent_size: int = 16  # smaller than JapaneseConfig's default (64)
    mask_prob: float = 0.2  # presumably a masking probability
    max_len: int = 10  # presumably a maximum sequence length


@dataclasses.dataclass(frozen=True)
class JapaneseConfig:
    """Immutable hyperparameter defaults returned by ``get_config`` for "ja".

    Has the same field names as ``KoreanConfig`` but larger ``latent_size``
    and ``max_len`` defaults. Instances are frozen. The per-field notes are
    read off the field names only; the consuming code is not in this file.
    """

    # Optimisation schedule.
    lr: float = 3e-4  # presumably the optimiser learning rate
    steps: int = 1000  # presumably the number of training steps
    warmup_ratio: float = 0.15  # original inline note repeated the value 0.15

    # Model / data shape.
    latent_size: int = 64  # original inline note: 24 (presumably an earlier value)
    mask_prob: float = 0.2  # presumably a masking probability
    max_len: int = 20  # presumably a maximum sequence length


# Alias for "either per-language config class". Uses the `type` statement
# (PEP 695), so this module requires Python 3.12 or newer.
type Config = KoreanConfig | JapaneseConfig


def get_config(
    lang: Literal["ko", "ja"],
    config_dict: dict | None = None,
) -> "KoreanConfig | JapaneseConfig":
    """Build the per-language config, optionally overriding default fields.

    Args:
        lang: ``"ko"`` selects ``KoreanConfig``. Any other value (including
            ``"ja"``) selects ``JapaneseConfig``; unknown codes are not
            rejected.
        config_dict: Field-name -> value overrides passed to the config
            constructor as keyword arguments. ``None`` (the default) or an
            empty dict yields the class defaults.

    Returns:
        A frozen ``KoreanConfig`` or ``JapaneseConfig`` instance.

    Raises:
        TypeError: If ``config_dict`` contains a key that is not a field of
            the selected config class.
    """
    # A ``None`` sentinel replaces the former ``{}`` default so that no dict
    # object is shared between calls.
    overrides = {} if config_dict is None else config_dict
    config_cls = KoreanConfig if lang == "ko" else JapaneseConfig
    return config_cls(**overrides)


@dataclasses.dataclass(frozen=True)
class Stage2Config:
    lr: float = 5e-4
    steps: int = 300
    weight_decay: float = 1e-2
    warmup_ratio: float = 0.1
    grad_clip_norm: float | None = 1.0

    diffusion_timesteps: int = 1000

    # cycle latent loss
    cycle_start_epoch: int = 5
    cycle_ramp_epochs: int = 10
    max_lambda_cycle_latent: float = 0.1

    # token cycle loss
    token_cycle_start_epoch: int = 20
    token_ko_scale: float = 0.04
    token_ja_scale: float = 0.1

    # direct domain MMD loss
    domain_kj_scale: float = 0.08
    domain_jk_scale: float = 0.09

    # prior
    prior_kj_scale: float = 1.0
    prior_jk_scale: float = 1.0

    # Cylce Start
    cycle_start_timesteps: list[tuple[int, int]] = dataclasses.field(
        default_factory=lambda: [
            (0, 500),
        ]
    )

    repeat_penalty_ja_scale: float = 0.05
    repeat_penalty_ko_scale: float = 0.02

    sep_center_scale: float = 0.1
    sep_count_scale: float = 0.1
    sep_peak_scale: float = 0.05

    use_gradient_checkpointing: bool = False
    use_amp: bool = True