# Ablation config: T=4 diffusion steps. T is read from the DIFFUSION_STEPS
# env var (the ablation run sets it to 4); the in-file default is 128.
import os

import torch

def _get_env_int(name, default):
    """Read an int from the environment, falling back to `default`."""
    value = os.environ.get(name)
    return int(value) if value is not None else default

def _get_env_str(name, default):
    """Read a string from the environment, falling back to `default`."""
    return os.environ.get(name, default)
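
# Boolean reader to match the helpers above; accepting "true"/"1" in any case
# is an assumption here; the original compared against the exact string "True".
def _get_env_bool(name, default):
    """Read a bool from the environment, falling back to `default`."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in ("true", "1")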
# BASH-CONTROLLED SWITCHES (defaults below apply when run manually)
MODEL = _get_env_str("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = _get_env_bool("INCLUDE_NEG", False)
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
    "TRAIN_DEVICE",
    "mps" if torch.backends.mps.is_available() else "cpu",
)
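
# Example ablation launch (illustrative; `train.py` is an assumed entry-point
# name, since this file only defines the config):
#   DIFFUSION_STEPS=4 INFERENCE_NUM_STEPS=4 MODEL_TYPE=d3pm_encoder_decoder \
#       INCLUDE_NEG=False python train.py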
CONFIG = {
    "model_type": MODEL,
    "data": {
        "include_negative_examples": NEGATIVES,
        "dataset_size": 60000,
    },
| # "model": { | |
| # "vocab_size": 16000, | |
| # "max_seq_len": 80, | |
| # "diffusion_steps": 10, | |
| # "d_model": 384, | |
| # "n_layers": 6, | |
| # "n_heads": 6, | |
| # "d_ff": 1536, | |
| # "dropout": 0.15 | |
| # }, | |
| # | |
| # "diffusion": { | |
| # "mask_token_id": 0 | |
| # }, | |
| # | |
| # "training": { | |
| # "batch_size": 32, | |
| # "epochs": 10, | |
| # "lr": 2e-4, | |
| # "label_smoothing": 0.05, | |
| # "precision": "float32", | |
| # "device": "mps" if torch.backends.mps.is_available() else "cpu", | |
| # "early_stopping_patience": 3 | |
| # } | |
| # "model": { | |
| # "vocab_size": 16000, | |
| # "max_seq_len": 96, # Optimized for GRETIL slokas | |
| # "diffusion_steps": 16, # Use 16 steps (better than 8) | |
| # "d_model": 512, # Wider model learns faster | |
| # "n_layers": 8, | |
| # "n_heads": 8, | |
| # "d_ff": 2048, | |
| # "dropout": 0.1 | |
| # }, | |
| # | |
| # "diffusion": { | |
| # "mask_token_id": 0 | |
| # }, | |
| # | |
| # "training": { | |
| # "batch_size": 32, | |
| # "epochs": 20, # 20 is enough with these tweaks | |
| # "lr": 4e-4, # Higher LR + Warmup for speed | |
| # "label_smoothing": 0.15, # Increased for 16k vocab stability | |
| # "precision": "float32", | |
| # "device": "mps" if torch.backends.mps.is_available() else "cpu", | |
| # "early_stopping_patience": 5 | |
| # } | |
    "diffusion": {
        "mask_token_id": 0,             # [MASK] = ID 0, fixed by the tokenizer
    },
    # ── Model architecture ──────────────────────────────────────────
    "model": {
        "src_vocab_size": 16000,        # Roman/IAST BPE vocab
        "tgt_vocab_size": 16000,        # Devanagari BPE vocab
        "d_model": 1024,                # up from 384/512; shared embeds save params
        "n_heads": 8,                   # 1024 / 8 = 128 head_dim
        "d_ff": 4096,                   # 4 × d_model
        "n_layers": 8,
        "dropout": 0.2,
        "max_seq_len": 80,
        "diffusion_steps": DIFFUSION_STEPS,
    },
    # ── Training ────────────────────────────────────────────────────
    "training": {
        "epochs": 20,                   # target: 0.71 → 0.83-0.85 in 5 epochs
        "batch_size": 32,
        "accum_steps": 2,               # effective batch = 64
        "lr": 7e-5,                     # lowered from 3e-4 (briefly 6e-4); warmup protects the first steps
        "label_smoothing": 0.1,         # was 0.0; reduces overconfidence (gap 1.7 nats)
        "patience": 4,                  # early stop after 4 non-improving epochs
        "l1_lambda": 1e-7,              # very light L1 regularization
        "device": TRAIN_DEVICE,
    },
    # ── Inference (used during val BERTScore and generate()) ────────
    "inference": {
        "num_steps": INFERENCE_STEPS,
        "temperature": 0.7,             # slightly lower = more confident output
        "top_k": 40,
        "repetition_penalty": 1.2,
        "diversity_penalty": 0.5,       # global-mean penalty is conservative; set 0.0 to turn off
    },
}
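
# Lightweight sanity checks (a minimal sketch; these invariants follow from
# the comments above and are not guaranteed by any training code):
assert CONFIG["model"]["d_model"] % CONFIG["model"]["n_heads"] == 0, \
    "d_model must split evenly across attention heads"
assert CONFIG["inference"]["num_steps"] <= CONFIG["model"]["diffusion_steps"], \
    "inference cannot use more steps than the diffusion schedule defines"
assert 0 <= CONFIG["diffusion"]["mask_token_id"] < CONFIG["model"]["tgt_vocab_size"], \
    "mask_token_id must be a valid target-vocab ID"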