"""Training entry point.

Pipeline: parse CLI args -> load config -> build dataset -> train a
tokenizer on the dataset text -> encode the text -> batch the token ids
-> run the Trainer.

Usage:
    python train.py -p path/to/config
"""
from argparse import ArgumentParser

import torch

from dataset import Dataset
from logger import Wandb
from tokenizer import Tokenizer
from trainer import Trainer
from util import ConfigParser

parser = ArgumentParser(
    prog='Trainer implementation, using Pytorch',
    description='',
)

if __name__ == '__main__':
    parser.add_argument('-p', '--config_path')
    args = parser.parse_args()

    # Project-specific config object; attribute access (config.dataset,
    # config.tokenizer, ...) suggests a namespace-style wrapper — see util.
    config = ConfigParser(args.config_path).config

    dataset = Dataset(config.dataset)

    # Tokenizer vocabulary is learned from the dataset text, so the model's
    # vocab_size is only known after this step.
    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)
    config.model.params.vocab_size = tokenizer.vocab_size

    # NOTE(review): `batches` is never used below — Trainer.train receives
    # the dataset itself, presumably re-reading the batched data internally;
    # confirm whether dataset.batch() mutates the dataset as a side effect.
    batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches
    print(f"batches: {num_batches}")

    trainer = Trainer(config)
    trainer.train(dataset)