Spaces:
Sleeping
Sleeping
File size: 1,425 Bytes
f7a8d72 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 | import torch
from transformers import TrainingArguments, Trainer
from src.config import Config
from src.dataset import get_dataset
from src.model import get_model
# =========================
# LOAD DATA
# =========================
# Seed shared by the train/validation split and the Trainer. Falls back to 42
# (the TrainingArguments default seed) when Config does not define SEED.
SEED = getattr(Config, "SEED", 42)
# Fraction of the dataset used for training; the remainder is validation.
TRAIN_FRACTION = 0.8
# Directory for intermediate checkpoints and the final saved model.
OUTPUT_DIR = "outputs/model"

full_dataset = get_dataset()
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
# A dedicated seeded generator makes the split identical across runs. The
# Trainer only seeds its RNGs when it is constructed (after this point), so
# without this the validation set would differ on every run.
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)
print(f"Train size: {len(train_dataset)}")
print(f"Validation size: {len(val_dataset)}")
# =========================
# MODEL
# =========================
model = get_model()
# =========================
# TRAINING CONFIG
# =========================
# No evaluation schedule is configured here. The validation split is instead
# evaluated explicitly after training (see EVALUATE below), which avoids
# depending on the evaluation_strategy / eval_strategy argument name.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=Config.BATCH_SIZE,
    per_device_eval_batch_size=Config.BATCH_SIZE,
    num_train_epochs=Config.EPOCHS,
    learning_rate=Config.LEARNING_RATE,
    logging_steps=50,
    save_steps=500,
    seed=SEED,
)
# =========================
# TRAINER
# =========================
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
# =========================
# TRAIN
# =========================
trainer.train()
# =========================
# EVALUATE
# =========================
# Skip evaluation when the split left no validation examples (tiny datasets).
if len(val_dataset) > 0:
    metrics = trainer.evaluate()
    print(f"Validation metrics: {metrics}")
# =========================
# SAVE MODEL
# =========================
trainer.save_model(OUTPUT_DIR)
print("Model saved successfully!")