# hmnshudhmn24's picture
# Upload 18 files
# 9a1ddeb verified
from transformers import Trainer, TrainingArguments
def get_trainer(model, tokenizer, dataset, cfg):
    """Build a ``transformers.Trainer`` from a config mapping and a dataset.

    Args:
        model: Model instance handed to ``Trainer`` unchanged.
        tokenizer: Tokenizer handed to ``Trainer`` unchanged (as
            ``processing_class`` when the installed release supports that
            keyword, otherwise as ``tokenizer``).
        dataset: Object indexable with ``"train"`` and ``"validation"``;
            those two entries become the train and eval datasets.
        cfg: Object indexable with the keys ``"output_dir"``,
            ``"batch_size"``, ``"epochs"``, ``"learning_rate"``,
            ``"logging_steps"`` and ``"evaluation_strategy"``.

    Returns:
        A configured ``Trainer``. Checkpoints are always saved once per
        epoch (``save_strategy="epoch"``).

    Raises:
        Whatever ``cfg`` / ``dataset`` raise for a missing key
        (``KeyError`` for a plain ``dict``).
    """
    # Function-scope import so this block does not depend on edits to the
    # file's import section.
    import inspect

    training_kwargs = {
        "output_dir": cfg["output_dir"],
        "per_device_train_batch_size": cfg["batch_size"],
        "num_train_epochs": cfg["epochs"],
        "learning_rate": cfg["learning_rate"],
        "logging_steps": cfg["logging_steps"],
        "save_strategy": "epoch",
    }

    # Newer transformers releases renamed ``evaluation_strategy`` to
    # ``eval_strategy``; the old spelling is deprecated and eventually
    # rejected. Use whichever keyword the installed version accepts.
    accepted_arg_names = inspect.signature(TrainingArguments.__init__).parameters
    if "eval_strategy" in accepted_arg_names:
        eval_keyword = "eval_strategy"
    else:
        eval_keyword = "evaluation_strategy"
    training_kwargs[eval_keyword] = cfg["evaluation_strategy"]

    args = TrainingArguments(**training_kwargs)

    # Same story for ``Trainer``: ``tokenizer`` was superseded by
    # ``processing_class`` in newer releases.
    accepted_trainer_names = inspect.signature(Trainer.__init__).parameters
    if "processing_class" in accepted_trainer_names:
        tokenizer_keyword = "processing_class"
    else:
        tokenizer_keyword = "tokenizer"

    return Trainer(
        model=model,
        args=args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["validation"],
        **{tokenizer_keyword: tokenizer},
    )