| from argparse import ArgumentParser | |
| import torch | |
| from util import ConfigParser | |
| from logger import Wandb | |
| from trainer import Trainer | |
| from dataset import Dataset | |
| from tokenizer import Tokenizer | |
# Command-line interface for the training launcher.
# NOTE: `prog` is the program *name* argparse shows in usage/help output;
# the long human-readable text belongs in `description` (they were
# previously swapped, leaving `description` empty).
parser = ArgumentParser(
    prog='train',
    description='Trainer implementation, using PyTorch',
)
def main() -> None:
    """Parse CLI arguments, build the dataset/tokenizer, and run training.

    Reads a config file given via ``-p/--config_path``, trains a tokenizer
    on the dataset's text, injects the resulting vocab size and batch count
    into the config, then hands the config to ``Trainer``.
    """
    # `required=True` makes a missing path fail fast with a clear argparse
    # error instead of passing None into ConfigParser and crashing later.
    parser.add_argument('-p', '--config_path', required=True,
                        help='path to the training config file')
    args = parser.parse_args()

    config = ConfigParser(args.config_path).config

    dataset = Dataset(config.dataset)
    tokenizer = Tokenizer()
    tokenizer.train(dataset.text, max_length=config.tokenizer.max_length)
    ids = tokenizer.c_encode(dataset.text)

    # The model's vocab size is only known after tokenizer training,
    # so it is patched into the config here.
    config.model.params.vocab_size = tokenizer.vocab_size

    # NOTE(review): the batches themselves are unused — only the count is
    # consumed. Presumably Trainer re-batches from `dataset` internally;
    # confirm, otherwise this call is wasted work.
    _batches, num_batches = dataset.batch(ids)
    config.trainer.num_batches = num_batches
    print(f"batches: {num_batches}")

    trainer = Trainer(config)
    trainer.train(dataset)


if __name__ == '__main__':
    main()