File size: 1,775 Bytes
e27ab6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from pathlib import Path
import torch

# Path Configuration
# Every path is assembled from separate components rather than from a
# Windows-style string such as r"artifacts\models". On POSIX systems a
# backslash is an ordinary filename character, so the raw-string form would
# name one oddly-named entry instead of a nested directory. Building from
# components yields the same relative paths on Windows and correct ones
# everywhere else.
DATA_PATH = Path("data", "IWSLT-15-en-vi")

# Tokenizer file name; alternatives are kept below for quick switching.
# TOKENIZER_NAME = ""
# TOKENIZER_NAME = "iwslt_en-vi_tokenizer_16k.json"
TOKENIZER_NAME = "iwslt_en-vi_tokenizer_32k.json"
TOKENIZER_PATH = Path("artifacts", "tokenizers") / TOKENIZER_NAME

MODEL_DIR = Path("artifacts", "models")

# Model file name; it is also reused as the checkpoint file name below.
# MODEL_NAME = ""
# MODEL_NAME = "transformer_en_vi_iwslt_1.pt"
MODEL_NAME = "transformer_en_vi_iwslt_1.safetensors"

# NOTE: the active save path points at a fixed file name inside MODEL_DIR,
# not at MODEL_NAME; the alternatives below are kept for reference.
# MODEL_SAVE_PATH = MODEL_DIR / MODEL_NAME
MODEL_SAVE_PATH = MODEL_DIR / "transformer_en_vi_iwslt_kaggle_1.safetensors"
# MODEL_SAVE_PATH = Path("notebooks", "models") / MODEL_NAME

CHECKPOINT_PATH = Path("artifacts", "checkpoints") / MODEL_NAME

# Left empty here; how consumers interpret an empty cache directory is not
# visible in this file -- TODO confirm against the code that reads it.
CACHE_DIR = ""


# Hardware & Data Config
# Use the GPU whenever PyTorch reports CUDA as available, else the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)

# Worker count -- presumably for data loading; verify against the caller.
NUM_WORKERS: int = 4

# Agrees with the "32k" tokenizer file selected in the path configuration.
VOCAB_SIZE: int = 32_000

# The order matters: each token's position equals its ID constant below.
SPECIAL_TOKENS: list[str] = [
    "[PAD]",
    "[UNK]",
    "[SOS]",
    "[EOS]",
]

# Number of samples to use (small value active; larger alternative below).
NUM_SAMPLES_TO_USE: int = 1000
# NUM_SAMPLES_TO_USE: int = 1_000_000


# Tokenizer Constants
# IDs mirror the positions of the entries in SPECIAL_TOKENS.
PAD_TOKEN_ID: int = 0  # "[PAD]"
UNK_TOKEN_ID: int = 1  # "[UNK]"
SOS_TOKEN_ID: int = 2  # "[SOS]"
EOS_TOKEN_ID: int = 3  # "[EOS]"


# Model Hyperparameters
# The active values follow "the paper" cited by the original comments
# (presumably "Attention Is All You Need" -- confirm). The commented-out
# lines are a smaller alternative configuration.
# D_MODEL: int = 256  # smaller model dimension
D_MODEL: int = 512  # model (embedding) dimension
N_LAYERS: int = 6  # N = 6 in the paper
N_HEADS: int = 8  # h = 8 in the paper
# D_FF: int = 1024  # d_ff = 4 * d_model for the 256-dim variant
D_FF: int = 2048  # feed-forward width; equals 4 * D_MODEL here
DROPOUT: float = 0.1  # dropout = 0.1 in the paper
MAX_SEQ_LEN: int = 150  # maximum length for the positional encoding


# Training Configuration
# Commented-out values are alternatives kept for quick switching.
# LEARNING_RATE: float = 1e-4
LEARNING_RATE: float = 5e-4
BATCH_SIZE: int = 32
EPOCHS: int = 5
# EPOCHS: int = 50

# Hugging Face Hub
# Repository identifier and the weights file name within it. FILENAME is the
# same file name that the model save path in this module ends with.
REPO_ID: str = "AlainDeLong/transformer-en-vi-base"
FILENAME: str = "transformer_en_vi_iwslt_kaggle_1.safetensors"

if __name__ == "__main__":
    # Running the module directly reports which device was selected.
    device_report = f"Using device: {DEVICE}"
    print(device_report)