| import argparse |
| import random |
| import torch |
| import numpy as np |
| from time import time |
| import logging |
|
|
| from torch.utils.data import DataLoader |
|
|
| from datasets import EmbDataset |
| from models.rqvae import RQVAE |
| from trainer import Trainer |
|
|
def _str2bool(value):
    """Convert a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string (including ``"False"`` and ``"0"``). Boolean options
    therefore need an explicit converter.

    Args:
        value: Raw option string. A bool is returned unchanged.

    Returns:
        True for yes/true/t/y/1 and False for no/false/f/n/0
        (case-insensitive, surrounding whitespace ignored).

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling. argparse turns this into a usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"yes", "true", "t", "y", "1"}:
        return True
    if normalized in {"no", "false", "f", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


def parse_args(argv=None):
    """Parse the command-line options for RQ-VAE index training.

    The options cover optimisation (lr, epochs, batch size, optimizer), the
    input embedding file, RQVAE architecture and quantisation settings, the
    target device, and the checkpoint directory.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, which preserves the original
            zero-argument behaviour.

    Returns:
        argparse.Namespace holding every option below.
    """
    parser = argparse.ArgumentParser(description="Index")

    # Optimisation / data-loading options.
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--num_workers', type=int, default=4, help='number of DataLoader worker processes')
    parser.add_argument('--eval_step', type=int, default=50, help='eval step')
    parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
    parser.add_argument("--data_path", type=str,
                        default="../data/Games/Games.emb-llama-td.npy",
                        help="Input data path.")

    # Model / quantiser options.
    parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
    parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
    # _str2bool instead of type=bool: bool("False") is True, which made these
    # flags impossible to turn off (or on) from the command line.
    parser.add_argument("--bn", type=_str2bool, default=False, help="use bn or not")
    parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
    parser.add_argument("--kmeans_init", type=_str2bool, default=True, help="use kmeans_init or not")
    parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
    parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
    parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")

    # NOTE(review): the default targets the second GPU; override with
    # --device cpu or --device cuda:0 on single-GPU or CPU-only machines.
    parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")

    parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256, 256, 256], help='emb num of every vq')
    parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
    parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
    parser.add_argument('--layers', type=int, nargs='+', default=[2048, 1024, 512, 256, 128, 64], help='hidden sizes of every layer')

    parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")

    return parser.parse_args(argv)
|
|
|
|
if __name__ == '__main__':
    # Reproducibility: seed Python, NumPy, PyTorch (CPU and every CUDA device)
    # with one fixed value, and pin cuDNN to deterministic, non-autotuned
    # kernels.
    seed_value = 2023
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed,
                    torch.cuda.manual_seed_all):
        seed_fn(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    args = parse_args()
    print(args)

    logging.basicConfig(level=logging.DEBUG)

    # Build the dataset first: its ``dim`` attribute sets the model's input
    # width.
    data = EmbDataset(args.data_path)

    # Every remaining RQVAE hyper-parameter comes straight from the CLI.
    rqvae_kwargs = dict(
        num_emb_list=args.num_emb_list,
        e_dim=args.e_dim,
        layers=args.layers,
        dropout_prob=args.dropout_prob,
        bn=args.bn,
        loss_type=args.loss_type,
        quant_loss_weight=args.quant_loss_weight,
        kmeans_init=args.kmeans_init,
        kmeans_iters=args.kmeans_iters,
        sk_epsilons=args.sk_epsilons,
        sk_iters=args.sk_iters,
    )
    model = RQVAE(in_dim=data.dim, **rqvae_kwargs)
    print(model)

    data_loader = DataLoader(
        data,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    # Trainer.fit reports the best loss and best collision rate it observed.
    trainer = Trainer(args, model)
    best_loss, best_collision_rate = trainer.fit(data_loader)

    print("Best Loss", best_loss)
    print("Best Collision Rate", best_collision_rate)
|
|
|
|