# Ablation config for the T=64 diffusion-steps run. The defaults below
# fall back to T=128 training steps / 64 inference steps; the bash driver
# overrides them through environment variables.
import os
import torch


def _get_env_int(name, default):
    value = os.environ.get(name)
    return int(value) if value is not None else default


def _get_env_str(name, default):
    return os.environ.get(name, default)
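

def _get_env_bool(name, default):
    """Parse a boolean env var, accepting 1/true/yes (case-insensitive).

    Added convenience helper (not in the original file): the original
    compared INCLUDE_NEG against the literal string "True" only; this
    also accepts INCLUDE_NEG=true or INCLUDE_NEG=1 exported from bash.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("1", "true", "yes")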


# πŸŽ›οΈ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)
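
# Example invocation from the bash driver (illustrative only; `train.py`
# is a placeholder name, not confirmed by this file):
#
#   MODEL_TYPE=d3pm_encoder_decoder INCLUDE_NEG=True \
#   DIFFUSION_STEPS=64 INFERENCE_NUM_STEPS=64 TRAIN_DEVICE=mps \
#   python train.py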

CONFIG = {
    "model_type": MODEL,

    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    # ── Earlier ablation configs (kept for reference) ─────────────────
    #
    # Run 1: small model, T=10.
    # "model": {
    #     "vocab_size": 16000,
    #     "max_seq_len": 80,
    #     "diffusion_steps": 10,
    #     "d_model": 384,
    #     "n_layers": 6,
    #     "n_heads": 6,
    #     "d_ff": 1536,
    #     "dropout": 0.15
    # },
    # "diffusion": {"mask_token_id": 0},
    # "training": {
    #     "batch_size": 32,
    #     "epochs": 10,
    #     "lr": 2e-4,
    #     "label_smoothing": 0.05,
    #     "precision": "float32",
    #     "device": "mps" if torch.backends.mps.is_available() else "cpu",
    #     "early_stopping_patience": 3
    # },
    #
    # Run 2: wider model, T=16, tuned for GRETIL slokas.
    # "model": {
    #     "vocab_size": 16000,
    #     "max_seq_len": 96,       # optimized for GRETIL slokas
    #     "diffusion_steps": 16,   # 16 steps (better than 8)
    #     "d_model": 512,          # wider model learns faster
    #     "n_layers": 8,
    #     "n_heads": 8,
    #     "d_ff": 2048,
    #     "dropout": 0.1
    # },
    # "diffusion": {"mask_token_id": 0},
    # "training": {
    #     "batch_size": 32,
    #     "epochs": 20,            # 20 is enough with these tweaks
    #     "lr": 4e-4,              # higher LR + warmup for speed
    #     "label_smoothing": 0.15, # increased for 16k-vocab stability
    #     "precision": "float32",
    #     "device": "mps" if torch.backends.mps.is_available() else "cpu",
    #     "early_stopping_patience": 5
    # },
    # ── Diffusion process ─────────────────────────────────────────────
    "diffusion": {
        "mask_token_id": 0,          # [MASK] = ID 0, fixed by the tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    "model": {
        # Separate src/tgt vocabs replaced the earlier shared "vocab_size": 16000.
        "src_vocab_size":   16000,   # Roman/IAST BPE vocab
        "tgt_vocab_size":   16000,   # Devanagari BPE vocab
        "d_model":          1024,    # was 384, then 512, in earlier runs
        "n_heads":          8,       # 1024 / 8 = 128 head_dim
        "d_ff":             4096,    # 4 × d_model (was 1536, then 2048)
        "n_layers":         8,       # was 4
        "dropout":          0.2,
        "max_seq_len":      80,
        "diffusion_steps":  DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    "training": {
        "epochs":           20,      # target: val score 0.71 → 0.83-0.85 within ~5 epochs
        "batch_size":       32,
        "accum_steps":      2,       # effective batch = 64
        "lr":               7e-5,    # lowered from 6e-4 (itself raised from 3e-4); warmup protects the first steps
        "label_smoothing":  0.1,     # was 0.0; reduces overconfidence (train/val gap ~1.7 nats)
        "patience":         4,       # early stop after 4 non-improving epochs
        "l1_lambda":        1e-7,    # very light L1 regularization
        "device":           TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    "inference": {
        "num_steps":          INFERENCE_STEPS,
        "temperature":        0.7,   # slightly lower = more confident output
        "top_k":              40,
        "repetition_penalty": 1.2,
        "diversity_penalty":  0.5,   # global-mean penalty is conservative
    },
}
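

# Quick sanity check (added sketch, not part of the training pipeline):
# running this module directly prints the resolved switches and verifies
# that the inference schedule fits inside the training schedule.
if __name__ == "__main__":
    assert CONFIG["inference"]["num_steps"] <= CONFIG["model"]["diffusion_steps"], (
        "inference num_steps must not exceed training diffusion_steps"
    )
    print(f"model={MODEL}  negatives={NEGATIVES}  device={TRAIN_DEVICE}")
    print(f"T_train={DIFFUSION_STEPS}  T_infer={INFERENCE_STEPS}")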