# devflow/config_T8.py  (uploaded artifact; original page chrome removed)
# Ablation config: T=8 diffusion steps.
# NOTE(review): DIFFUSION_STEPS defaults to 128 below and is env-driven —
# presumably the T=8 ablation run exports DIFFUSION_STEPS=8 from the launch
# script; confirm against the driving bash script.
import os
import torch
def _get_env_int(name, default):
    """Return the environment variable *name* parsed as an int.

    Falls back to *default* when the variable is unset. Raises
    ``ValueError`` if the variable is set but not a valid integer.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return int(raw)
def _get_env_str(name, default):
    """Return the environment variable *name*, or *default* when unset."""
    value = os.environ.get(name)
    if value is None:
        return default
    return value
# πŸŽ›οΈ BASH-CONTROLLED SWITCHES (Defaults if run manually)
MODEL = os.environ.get("MODEL_TYPE", "d3pm_encoder_decoder")
NEGATIVES = os.environ.get("INCLUDE_NEG", "False") == "True"
DIFFUSION_STEPS = _get_env_int("DIFFUSION_STEPS", 128)
INFERENCE_STEPS = _get_env_int("INFERENCE_NUM_STEPS", min(64, DIFFUSION_STEPS))
TRAIN_DEVICE = _get_env_str(
"TRAIN_DEVICE",
"mps" if torch.backends.mps.is_available() else "cpu",
)
# Single source of truth consumed by the training / evaluation scripts.
# Two stale, fully commented-out config variants that used to live inside
# this dict were removed as dead code (recoverable from version control).
CONFIG = {
    'model_type': MODEL,
    # ── Data ──────────────────────────────────────────────────────────
    'data': {
        'include_negative_examples': NEGATIVES,
        'dataset_size': 60000,
    },
    'diffusion': {
        'mask_token_id': 0,       # [MASK] = ID 0, fixed by tokenizer
    },
    # ── Model architecture ────────────────────────────────────────────
    'model': {
        'src_vocab_size': 16000,  # Roman/IAST BPE vocab
        'tgt_vocab_size': 16000,  # Devanagari BPE vocab
        'd_model': 1024,          # raised from 512; shared embeds save params
        'n_heads': 8,             # 1024 / 8 = 128 head_dim
        'd_ff': 4096,             # 4 × d_model
        'n_layers': 8,
        'dropout': 0.2,
        'max_seq_len': 80,
        'diffusion_steps': DIFFUSION_STEPS,   # env-driven (see switches above)
    },
    # ── Training ──────────────────────────────────────────────────────
    'training': {
        'epochs': 20,             # Target: 0.71→0.83-0.85 in 5 epochs
        'batch_size': 32,
        'accum_steps': 2,         # effective batch = 64
        'lr': 7e-5,               # lowered from 6e-4; warmup protects first steps
        'label_smoothing': 0.1,   # was 0.0; reduces overconfidence (gap 1.7 nats)
        'patience': 4,            # early stop after 4 non-improving epochs
        'l1_lambda': 1e-7,        # very light L1
        'device': TRAIN_DEVICE,
    },
    # ── Inference (used during val BERTScore and generate()) ──────────
    'inference': {
        'num_steps': INFERENCE_STEPS,
        'temperature': 0.7,       # slightly lower = more confident output
        'top_k': 40,
        'repetition_penalty': 1.2,
        'diversity_penalty': 0.5, # keep off; global-mean penalty is conservative
    },
}