# Ablation config: T=32 diffusion steps
import os
import torch
def _get_env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int; use *default* when unset."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    return int(raw)
def _get_env_str(name: str, default: str) -> str:
    """Read environment variable *name* as a string; use *default* when unset."""
    value = os.environ.get(name)
    return default if value is None else value
# Bash-controlled switches (defaults apply when the script is run manually).
MODEL = _get_env_str("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = _get_env_str("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
# Inference may denoise with fewer steps than training, capped at 64 by default.
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
# Prefer the Apple-Silicon GPU (MPS) when available; fall back to CPU.
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)
# Single source of truth for model/training/inference hyper-parameters.
# Values marked with env-derived constants are overridable from the shell.
CONFIG = {
    "model_type": MODEL,

    # ── Data ──────────────────────────────────────────────────────────
    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },

    # ── Diffusion process ─────────────────────────────────────────────
    'diffusion': {
        'mask_token_id': 0,  # [MASK] = ID 0, fixed by tokenizer
    },

    # ── Model architecture ────────────────────────────────────────────
    'model': {
        'src_vocab_size': 16000,  # Roman/IAST BPE vocab
        'tgt_vocab_size': 16000,  # Devanagari BPE vocab
        'd_model': 1024,
        'n_heads': 8,             # 1024 / 8 = 128 head_dim
        'd_ff': 4096,             # 4 x d_model
        'n_layers': 8,
        'dropout': 0.2,
        'max_seq_len': 80,
        'diffusion_steps': DIFFUSION_STEPS,
    },

    # ── Training ──────────────────────────────────────────────────────
    'training': {
        'epochs': 20,            # target: 0.71 → 0.83-0.85 in 5 epochs
        'batch_size': 32,
        'accum_steps': 2,        # effective batch = 32 * 2 = 64
        'lr': 7e-5,              # warmup protects the first steps
        'label_smoothing': 0.1,  # reduces overconfidence (gap 1.7 nats)
        'patience': 4,           # early stop after 4 non-improving epochs
        'l1_lambda': 1e-7,       # very light L1 regularization
        'device': TRAIN_DEVICE,
    },

    # ── Inference (used during val BERTScore and generate()) ──────────
    'inference': {
        'num_steps': INFERENCE_STEPS,
        'temperature': 0.7,         # slightly lower = more confident output
        'top_k': 40,
        'repetition_penalty': 1.2,
        'diversity_penalty': 0.5,   # global-mean penalty is conservative
    },
}