# trainer.cli.py — command-line entry point for the Mamba trainer.
# (Hugging Face file-viewer chrome removed: avatar caption, commit hash
# 35a42a7, "raw / history / blame", byte size — none of it is code.)
from argparse import ArgumentParser
import torch
from util import ConfigParser
from logger import Wandb
from trainer import Trainer
from dataset import Dataset
from tokenizer import Tokenizer
def main() -> None:
    """CLI entry point: parse arguments, prepare data/tokenizer, run training.

    Pipeline: load config -> build dataset -> train tokenizer on the raw
    text -> encode -> propagate vocab size and batch count into the config
    -> hand everything to ``Trainer``.
    """
    # `prog` is the program name; the human-readable text belongs in
    # `description` (the original had them swapped, leaving description empty).
    parser = ArgumentParser(
        prog='trainer.cli',
        description='Trainer implementation, using PyTorch',
    )
    parser.add_argument(
        '-p', '--config_path',
        required=True,  # fail fast with a clear argparse error instead of ConfigParser(None)
        help='Path to the training configuration file',
    )
    args = parser.parse_args()

    config = ConfigParser(args.config_path).config

    dataset = Dataset(config.dataset)

    # Tokenizer is trained on the raw dataset text, then used to encode it.
    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)

    # The model's vocabulary size must match what the tokenizer learned.
    config.model.params.vocab_size = tokenizer.vocab_size

    # NOTE(review): `batches` itself is unused here — presumably Trainer
    # re-batches from `dataset`; only the count is pushed into the config.
    batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches
    print(f"batches: {num_batches}")

    trainer = Trainer(config)
    trainer.train(dataset)


if __name__ == '__main__':
    main()