File size: 1,425 Bytes
f7a8d72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
from transformers import TrainingArguments, Trainer

from src.config import Config
from src.dataset import get_dataset
from src.model import get_model


# =========================
# LOAD DATA
# =========================
# Seed for the train/validation split. Without a seeded generator,
# ``random_split`` draws from the global RNG, so each run would hold out a
# different subset and validation results would not be comparable across runs.
SPLIT_SEED = 42

full_dataset = get_dataset()

# 80% of the samples are used for training; the remainder is held out for
# validation. Computing val_size by subtraction guarantees the two sizes sum
# to the dataset length even when 0.8 * len is not an integer.
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")


# =========================
# MODEL
# =========================
# Model construction is delegated to the project's factory function.
model = get_model()


# =========================
# TRAINING CONFIG (FULLY COMPATIBLE)
# =========================
# Batch size (shared by training and evaluation), epoch count and learning
# rate are read from the project Config; the logging and checkpoint intervals
# are fixed here.
_training_kwargs = {
    "output_dir": "outputs/model",
    "per_device_train_batch_size": Config.BATCH_SIZE,
    "per_device_eval_batch_size": Config.BATCH_SIZE,
    "num_train_epochs": Config.EPOCHS,
    "learning_rate": Config.LEARNING_RATE,
    "logging_steps": 50,
    "save_steps": 500,
}
training_args = TrainingArguments(**_training_kwargs)


# =========================
# TRAINER
# =========================
# Bind the model, its training arguments and both dataset splits together.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=val_dataset,
    train_dataset=train_dataset,
)


# =========================
# TRAIN
# =========================
trainer.train()

# The validation split is handed to the Trainer as ``eval_dataset``, but the
# training arguments in this script do not enable any evaluation schedule, so
# without an explicit call the held-out data would never be used. Run a single
# evaluation pass over it once training has finished and report the metrics.
eval_metrics = trainer.evaluate()
print(f"Validation metrics: {eval_metrics}")


# =========================
# SAVE MODEL
# =========================
# Write the final model to disk, then confirm on stdout.
final_model_dir = "outputs/model"
trainer.save_model(final_model_dir)
print("Model saved successfully!")