# Benchmark-Dual / utils.py
# Junyin's picture
# Add files using upload-large-folder tool
# 936459a verified
import json
import logging
import os
import random
import datetime
import numpy as np
import torch
from torch.utils.data import ConcatDataset
from data import *
# from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset
from data_finetune import *
# from data_finetune import SeqRecFinetune, ItemFeatFinetune, ItemSearchFinetune, FusionSeqRecFinetune, PreferenceObtainFinetune
def parse_evaluate_args(parser):
    """Register every command-line option used for evaluation on ``parser``.

    The options cover the global settings (seed, model and output paths),
    the dataset / prompt settings and the test-time settings.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    # Global options.
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--base_model", type=str, default="../llama-7b/", help="basic model path")
    parser.add_argument("--output_dir", type=str, default="./ckpt/", help="The output directory")

    # Dataset options.
    parser.add_argument("--data_path", type=str, default="", help="data directory")
    parser.add_argument(
        "--tasks",
        type=str,
        default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item',
        help="Downstream tasks, separate by comma",
    )
    parser.add_argument(
        "--train_data_sample_num",
        type=str,
        default="0,0,0,0,0,0,0,0,0",
        help="the number of sampling data for each task",
    )
    parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name")
    parser.add_argument("--index_file", type=str, default=".index.item.json", help="the item indices file")
    # The help text previously described this option as the *item* indices file.
    parser.add_argument("--user_index_file", type=str, default=".index.user.json", help="the user indices file")
    parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")

    # arguments related to sequential task
    parser.add_argument(
        "--max_his_len",
        type=int,
        default=20,
        help="the max number of items in history sequence, -1 means no limit",
    )
    parser.add_argument(
        "--add_prefix",
        action="store_true",
        default=False,
        help="whether add sequential prefix in history",
    )
    parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
    parser.add_argument(
        "--only_train_response",
        action="store_true",
        default=False,
        help="whether only train on responses",
    )
    parser.add_argument(
        "--train_prompt_sample_num",
        type=str,
        default="1,1,1,1,1,1,1,1,1",
        help="the number of sampling prompts for each task",
    )
    parser.add_argument("--valid_prompt_id", type=int, default=0, help="The prompt used for validation")
    parser.add_argument(
        "--sample_valid",
        action="store_true",
        default=True,
        help="use sampled prompt for validation",
    )
    # ``--sample_valid`` defaults to True, so on its own it can never be
    # switched off from the command line; this companion flag allows that
    # while leaving the existing flag and its default untouched.
    parser.add_argument(
        "--no_sample_valid",
        dest="sample_valid",
        action="store_false",
        help="do not use sampled prompt for validation",
    )
    parser.add_argument(
        "--valid_prompt_sample_num",
        type=int,
        default=2,
        help="the number of sampling validation sequential recommendation prompts",
    )

    # Test options.
    parser.add_argument("--ckpt_path", type=str, default="", help="The checkpoint path")
    parser.add_argument("--lora", action="store_true", default=False)
    parser.add_argument(
        "--filter_items",
        action="store_true",
        default=False,
        help="whether filter illegal items",
    )
    parser.add_argument("--results_file", type=str, default="./results/test-ddp.json", help="result output path")
    parser.add_argument("--test_batch_size", type=int, default=1)
    parser.add_argument("--num_beams", type=int, default=20)
    parser.add_argument(
        "--sample_num",
        type=int,
        default=-1,
        help="test sample number, -1 represents using all test data",
    )
    parser.add_argument("--gpu_id", type=int, default=0, help="GPU ID when testing with single GPU")
    parser.add_argument(
        "--test_prompt_ids",
        type=str,
        default="0",
        help="test prompt ids, separate by comma. 'all' represents using all",
    )
    parser.add_argument(
        "--metrics",
        type=str,
        default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10",
        help="test metrics, separate by comma",
    )
    parser.add_argument("--test_task", type=str, default="SeqRec")
    return parser
def parse_finetune_args(parser):
    """Register every command-line option used for fine-tuning on ``parser``.

    The options cover the global settings, the dataset / prompt settings,
    the optimisation and LoRA settings, plus ``--reindex`` and
    ``--ckpt_path``.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    # Global options.
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--base_model", type=str, default="../llama-7b/", help="basic model path")
    parser.add_argument("--output_dir", type=str, default="./ckpt/", help="The output directory")

    # Dataset options.
    parser.add_argument("--data_path", type=str, default="", help="data directory")
    parser.add_argument(
        "--tasks",
        type=str,
        default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item',
        help="Downstream tasks, separate by comma",
    )
    parser.add_argument(
        "--train_data_sample_num",
        type=str,
        default="0,0,0,0,0,0,0,0,0",
        help="the number of sampling data for each task",
    )
    parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name")
    parser.add_argument("--index_file", type=str, default=".index.json", help="item indices file")
    parser.add_argument("--user_index_file", type=str, default=".user-index.json", help="user indices file")
    parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")

    # Sequential-task and prompt options.
    parser.add_argument(
        "--max_his_len",
        type=int,
        default=20,
        help="the max number of items in history sequence, -1 means no limit",
    )
    parser.add_argument(
        "--add_prefix",
        action="store_true",
        default=False,
        help="whether add sequential prefix in history",
    )
    parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
    parser.add_argument(
        "--only_train_response",
        action="store_true",
        default=False,
        help="whether only train on responses",
    )
    parser.add_argument(
        "--train_prompt_sample_num",
        type=str,
        default="1,1,1,1,1,1,1,1,1",
        help="the number of sampling prompts for each task",
    )
    parser.add_argument("--valid_prompt_id", type=int, default=0, help="The prompt used for validation")
    parser.add_argument(
        "--sample_valid",
        action="store_true",
        default=True,
        help="use sampled prompt for validation",
    )
    # ``--sample_valid`` defaults to True, so on its own it can never be
    # switched off from the command line; this companion flag allows that
    # while leaving the existing flag and its default untouched.
    parser.add_argument(
        "--no_sample_valid",
        dest="sample_valid",
        action="store_false",
        help="do not use sampled prompt for validation",
    )
    parser.add_argument(
        "--valid_prompt_sample_num",
        type=int,
        default=2,
        help="the number of sampling validation sequential recommendation prompts",
    )

    # Optimisation options.
    parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer')
    parser.add_argument("--epochs", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--per_device_batch_size", type=int, default=8)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2)
    parser.add_argument("--logging_step", type=int, default=10)
    parser.add_argument("--model_max_length", type=int, default=2048)
    parser.add_argument("--weight_decay", type=float, default=0.01)

    # LoRA options.
    parser.add_argument("--lora_r", type=int, default=8)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)
    parser.add_argument(
        "--lora_target_modules",
        type=str,
        default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj",
        help="separate by comma",
    )
    parser.add_argument(
        "--lora_modules_to_save",
        type=str,
        default="embed_tokens,lm_head",
        help="separate by comma",
    )

    # Schedule, checkpointing and precision options.
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="either training checkpoint or final adapter",
    )
    parser.add_argument("--warmup_ratio", type=float, default=0.01)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--save_and_eval_strategy", type=str, default="epoch")
    parser.add_argument("--save_and_eval_steps", type=int, default=1000)
    parser.add_argument("--fp16", action="store_true", default=False)
    parser.add_argument("--bf16", action="store_true", default=False)
    parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json")
    parser.add_argument(
        "--remove_unused_columns",
        action="store_true",
        default=False,
        help='if remove unused columns',
    )

    parser.add_argument("--reindex", type=int, default=0)
    # parser.add_argument("--user_reindex", type = int, default = 0)
    parser.add_argument("--ckpt_path", type=str, default="")
    return parser
def load_finetune_datasets(args):
    """Build the fine-tuning training set and the validation set.

    One dataset is created per entry of ``args.tasks`` and the results are
    combined with ``ConcatDataset``. The validation set is always a
    ``SeqRecFinetune`` in ``"valid"`` mode.

    Args:
        args: Namespace providing the comma-separated strings ``tasks``,
            ``train_prompt_sample_num`` and ``train_data_sample_num`` plus
            ``valid_prompt_sample_num``. It is also forwarded unchanged to
            every dataset constructor.

    Returns:
        A ``(train_data, valid_data)`` tuple.

    Raises:
        AssertionError: If the number of prompt or data sample counts does
            not match the number of tasks.
        NotImplementedError: If a task name is not supported here.
    """
    # Surrounding whitespace is dropped so "seqrec, itemsearch" is accepted.
    tasks = [name.strip().lower() for name in args.tasks.split(",")]

    # The length checks are raised explicitly rather than through ``assert``
    # so they still run under ``python -O`` (otherwise ``zip`` below would
    # silently truncate); the exception type callers see is unchanged.
    train_prompt_sample_num = [int(count) for count in args.train_prompt_sample_num.split(",")]
    if len(tasks) != len(train_prompt_sample_num):
        raise AssertionError("prompt sample number does not match task number")
    train_data_sample_num = [int(count) for count in args.train_data_sample_num.split(",")]
    if len(tasks) != len(train_data_sample_num):
        raise AssertionError("data sample number does not match task number")

    # NOTE(review): the dataset classes below are presumably supplied by the
    # star import from ``data_finetune`` - confirm.
    train_datasets = []
    for task, prompt_sample_num, data_sample_num in zip(tasks, train_prompt_sample_num, train_data_sample_num):
        if task == "seqrec":
            dataset = SeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("item2index", "index2item"):
            dataset = ItemFeatFinetune(args, task=task, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "fusionseqrec":
            dataset = FusionSeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "itemsearch":
            dataset = ItemSearchFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "preferenceobtain":
            dataset = PreferenceObtainFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "usersearch":
            dataset = UserSearchFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("user2pref", "pref2user"):
            dataset = UserFeatFinetune(args, task=task, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        else:
            raise NotImplementedError("unsupported fine-tuning task: {!r}".format(task))
        train_datasets.append(dataset)

    train_data = ConcatDataset(train_datasets)
    valid_data = SeqRecFinetune(args, "valid", args.valid_prompt_sample_num)
    return train_data, valid_data
# def load_finetune_datasets(args):
# tasks = args.tasks.split(",")
# train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")]
# assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number"
# train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")]
# assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number"
# train_datasets = []
# for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num):
# if task.lower() == "seqrec":
# dataset = SeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
# elif task.lower() == "item2index" or task.lower() == "index2item":
# dataset = ItemFeatFinetune(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
# elif task.lower() in ["inters2title", "inters2description", "intertitles2item"]:
# dataset = FusionSeqRecFinetune(args, task=task.lower(), mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
# elif task.lower() in ["itemsearch", "query2item"]:
# dataset = ItemSearchFinetune(args, task=task.lower(),mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
# elif task.lower() == "preferenceobtain":
# dataset = PreferenceObtainFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
# else:
# raise NotImplementedError
# train_datasets.append(dataset)
# train_data = ConcatDataset(train_datasets)
# valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num)
# return train_data, valid_data
def parse_global_args(parser):
    """Register the options shared by every entry point on ``parser``.

    Adds ``--seed``, ``--base_model`` and ``--output_dir``.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        ("--seed", {"type": int, "default": 42, "help": "Random seed"}),
        ("--base_model", {"type": str, "default": "../llama-7b/", "help": "basic model path"}),
        ("--output_dir", {"type": str, "default": "./ckpt/", "help": "The output directory"}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
def parse_dataset_args(parser):
    """Register the dataset and prompt-sampling options on ``parser``.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument("--data_path", type=str, default="", help="data directory")
    parser.add_argument(
        "--tasks",
        type=str,
        default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item',
        help="Downstream tasks, separate by comma",
    )
    parser.add_argument(
        "--train_data_sample_num",
        type=str,
        default="0,0,0,0,0,0,0,0,0",
        help="the number of sampling data for each task",
    )
    parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name")
    parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file")
    parser.add_argument("--user_index_file", type=str, default=".user-index.json", help="user indices file")
    parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers")
    parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor")

    # arguments related to sequential task
    parser.add_argument(
        "--max_his_len",
        type=int,
        default=20,
        help="the max number of items in history sequence, -1 means no limit",
    )
    parser.add_argument(
        "--add_prefix",
        action="store_true",
        default=False,
        help="whether add sequential prefix in history",
    )
    parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history")
    parser.add_argument(
        "--only_train_response",
        action="store_true",
        default=False,
        help="whether only train on responses",
    )
    parser.add_argument(
        "--train_prompt_sample_num",
        type=str,
        default="1,1,1,1,1,1,1,1,1",
        help="the number of sampling prompts for each task",
    )
    parser.add_argument("--valid_prompt_id", type=int, default=0, help="The prompt used for validation")
    parser.add_argument(
        "--sample_valid",
        action="store_true",
        default=True,
        help="use sampled prompt for validation",
    )
    # ``--sample_valid`` defaults to True, so on its own it can never be
    # switched off from the command line; this companion flag allows that
    # while leaving the existing flag and its default untouched.
    parser.add_argument(
        "--no_sample_valid",
        dest="sample_valid",
        action="store_false",
        help="do not use sampled prompt for validation",
    )
    parser.add_argument(
        "--valid_prompt_sample_num",
        type=int,
        default=2,
        help="the number of sampling validation sequential recommendation prompts",
    )
    return parser
def parse_train_args(parser):
    """Register the optimisation, LoRA and checkpointing options on ``parser``.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        ("--optim", {"type": str, "default": "adamw_torch", "help": 'The name of the optimizer'}),
        ("--epochs", {"type": int, "default": 4}),
        ("--learning_rate", {"type": float, "default": 2e-5}),
        ("--per_device_batch_size", {"type": int, "default": 8}),
        ("--gradient_accumulation_steps", {"type": int, "default": 2}),
        ("--logging_step", {"type": int, "default": 10}),
        ("--model_max_length", {"type": int, "default": 2048}),
        ("--weight_decay", {"type": float, "default": 0.01}),
        ("--lora_r", {"type": int, "default": 8}),
        ("--lora_alpha", {"type": int, "default": 32}),
        ("--lora_dropout", {"type": float, "default": 0.05}),
        (
            "--lora_target_modules",
            {
                "type": str,
                "default": "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj",
                "help": "separate by comma",
            },
        ),
        (
            "--lora_modules_to_save",
            {"type": str, "default": "embed_tokens,lm_head", "help": "separate by comma"},
        ),
        (
            "--resume_from_checkpoint",
            {"type": str, "default": None, "help": "either training checkpoint or final adapter"},
        ),
        ("--warmup_ratio", {"type": float, "default": 0.01}),
        ("--lr_scheduler_type", {"type": str, "default": "cosine"}),
        ("--save_and_eval_strategy", {"type": str, "default": "epoch"}),
        ("--save_and_eval_steps", {"type": int, "default": 1000}),
        ("--fp16", {"action": "store_true", "default": False}),
        ("--bf16", {"action": "store_true", "default": False}),
        ("--deepspeed", {"type": str, "default": "./config/ds_z3_bf16.json"}),
        (
            "--remove_unused_columns",
            {"action": "store_true", "default": False, "help": 'if remove unused columns'},
        ),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
def _str2bool(value):
    """Convert a command-line string such as ``"False"`` or ``"1"`` to a bool.

    ``type=bool`` cannot be used for this: argparse would call
    ``bool("False")``, and any non-empty string is truthy.

    Raises:
        ValueError: If the text is not a recognised boolean spelling;
            argparse reports this as an invalid argument value.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``bool("")``.
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise ValueError("expected a boolean value, got {!r}".format(value))


def parse_rqvae_args(parser):
    """Register the RQ-VAE training options on ``parser``.

    ``--bn`` and ``--kmeans_init`` take an explicit value (for example
    ``--bn True`` or ``--bn False``) that is parsed as a real boolean.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
    # parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
    parser.add_argument('--num_workers', type=int, default=4)
    parser.add_argument('--eval_step', type=int, default=50, help='eval step')
    parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
    # parser.add_argument("--data_path", type=str,
    #                     default="../data/Games/Games.emb-llama-td.npy",
    #                     help="Input data path.")
    # parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
    parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
    # Was ``type=bool``, which turned "--bn False" into True.
    parser.add_argument("--bn", type=_str2bool, default=False, help="use bn or not")
    parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
    # Was ``type=bool``, which turned "--kmeans_init False" into True.
    parser.add_argument("--kmeans_init", type=_str2bool, default=False, help="use kmeans_init or not")
    parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
    parser.add_argument(
        '--sk_epsilons',
        type=float,
        nargs='+',
        default=[0.0, 0.0, 0.0, 0.0],
        help="sinkhorn epsilons",
    )
    parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")
    parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")
    parser.add_argument(
        '--num_emb_list',
        type=int,
        nargs='+',
        default=[256, 256, 256, 256],
        help='emb num of every vq',
    )
    parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
    parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight')
    parser.add_argument(
        '--layers',
        type=int,
        nargs='+',
        default=[2048, 1024, 512, 256, 128, 64],
        help='hidden sizes of every layer',
    )
    parser.add_argument("--ckpt_path", type=str, default="", help="output directory for model")
    parser.add_argument("--warmup", type=int, default=5, help="epochs for warmup")
    parser.add_argument("--item_model", type=str, default="", help="")
    parser.add_argument("--user_model", type=str, default="", help="")
    return parser
def parse_test_args(parser):
    """Register the test-time options on ``parser``.

    Args:
        parser: An object exposing ``add_argument`` (e.g. an
            ``argparse.ArgumentParser``). It is modified in place.

    Returns:
        The same ``parser`` object that was passed in.
    """
    option_table = (
        ("--ckpt_path", {"type": str, "default": "", "help": "The checkpoint path"}),
        ("--lora", {"action": "store_true", "default": False}),
        (
            "--filter_items",
            {"action": "store_true", "default": False, "help": "whether filter illegal items"},
        ),
        (
            "--results_file",
            {"type": str, "default": "./results/test-ddp.json", "help": "result output path"},
        ),
        ("--test_batch_size", {"type": int, "default": 1}),
        ("--num_beams", {"type": int, "default": 20}),
        (
            "--sample_num",
            {"type": int, "default": -1, "help": "test sample number, -1 represents using all test data"},
        ),
        ("--gpu_id", {"type": int, "default": 0, "help": "GPU ID when testing with single GPU"}),
        (
            "--test_prompt_ids",
            {
                "type": str,
                "default": "0",
                "help": "test prompt ids, separate by comma. 'all' represents using all",
            },
        ),
        (
            "--metrics",
            {
                "type": str,
                "default": "hit@1,hit@5,hit@10,ndcg@5,ndcg@10",
                "help": "test metrics, separate by comma",
            },
        ),
        ("--test_task", {"type": str, "default": "SeqRec"}),
    )
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)
    return parser
def get_local_time():
    """Return the current local time as a ``%b-%d-%Y_%H-%M-%S`` string.

    The result looks like ``Jan-31-2024_13-05-59``; it contains no spaces
    or colons.
    """
    return datetime.datetime.now().strftime("%b-%d-%Y_%H-%M-%S")
def set_seed(seed):
    """Seed every random number generator used here and pin the cuDNN flags.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA
    devices) with ``seed``. It then sets the cuDNN backend flags:
    ``benchmark`` off, ``deterministic`` on and ``enabled`` off.

    Args:
        seed: Seed value handed to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    # cuDNN is switched off entirely, not merely made deterministic.
    cudnn.enabled = False
def ensure_dir(dir_path):
    """Create ``dir_path`` and any missing parents; do nothing if it exists.

    Args:
        dir_path: Path of the directory to create.
    """
    os.makedirs(name=dir_path, exist_ok=True)
def load_datasets(args):
    """Build the training set and the validation set.

    One dataset is created per entry of ``args.tasks`` and the results are
    combined with ``ConcatDataset``. The validation set is always a
    ``SeqRecDataset`` in ``"valid"`` mode.

    Args:
        args: Namespace providing the comma-separated strings ``tasks``,
            ``train_prompt_sample_num`` and ``train_data_sample_num`` plus
            ``valid_prompt_sample_num``. It is also forwarded unchanged to
            every dataset constructor.

    Returns:
        A ``(train_data, valid_data)`` tuple.

    Raises:
        AssertionError: If the number of prompt or data sample counts does
            not match the number of tasks.
        NotImplementedError: If a task name is not supported here.
    """
    # Surrounding whitespace is dropped so "seqrec, itemsearch" is accepted.
    tasks = [name.strip().lower() for name in args.tasks.split(",")]

    # The length checks are raised explicitly rather than through ``assert``
    # so they still run under ``python -O`` (otherwise ``zip`` below would
    # silently truncate); the exception type callers see is unchanged.
    train_prompt_sample_num = [int(count) for count in args.train_prompt_sample_num.split(",")]
    if len(tasks) != len(train_prompt_sample_num):
        raise AssertionError("prompt sample number does not match task number")
    train_data_sample_num = [int(count) for count in args.train_data_sample_num.split(",")]
    if len(tasks) != len(train_data_sample_num):
        raise AssertionError("data sample number does not match task number")

    # NOTE(review): the dataset classes below are presumably supplied by the
    # star import from ``data`` - confirm.
    train_datasets = []
    for task, prompt_sample_num, data_sample_num in zip(tasks, train_prompt_sample_num, train_data_sample_num):
        if task == "seqrec":
            dataset = SeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("item2index", "index2item"):
            dataset = ItemFeatDataset(args, task=task, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("inters2title", "inters2description", "intertitles2item"):
            dataset = FusionSeqRecDataset(args, task=task, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("itemsearch", "query2item"):
            dataset = ItemSearchDataset(args, task=task, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "preferenceobtain":
            dataset = PreferenceObtainDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task == "usersearch":
            dataset = UserSearchDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        elif task in ("pref2user", "user2pref"):
            dataset = UserFeatDataset(args, task=task, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num)
        else:
            raise NotImplementedError("unsupported training task: {!r}".format(task))
        train_datasets.append(dataset)

    train_data = ConcatDataset(train_datasets)
    valid_data = SeqRecDataset(args, "valid", args.valid_prompt_sample_num)
    return train_data, valid_data
def load_test_dataset(args):
    """Build the test dataset selected by ``args.test_task``.

    The task name is matched case-insensitively, ignoring surrounding
    whitespace. ``"seqrec"`` maps to ``SeqRecFinetune``, ``"itemsearch"`` to
    ``ItemSearchDataset`` and ``"fusionseqrec"`` to ``FusionSeqRecDataset``,
    each built in ``"test"`` mode with ``sample_num=args.sample_num``.

    Args:
        args: Namespace providing ``test_task`` and ``sample_num``; it is
            also forwarded unchanged to the dataset constructor.

    Returns:
        The constructed test dataset.

    Raises:
        NotImplementedError: If ``args.test_task`` is not one of the
            supported task names.
    """
    test_task = args.test_task.strip().lower()
    # NOTE(review): "seqrec" uses the *Finetune* class while the other tasks
    # use the *Dataset* classes - confirm this mix is intentional.
    if test_task == "seqrec":
        return SeqRecFinetune(args, mode="test", sample_num=args.sample_num)
    if test_task == "itemsearch":
        return ItemSearchDataset(args, mode="test", sample_num=args.sample_num)
    if test_task == "fusionseqrec":
        return FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num)
    raise NotImplementedError("unsupported test task: {!r}".format(args.test_task))
# def load_test_dataset(args):
# if args.test_task.lower() == "seqrec":
# test_data = SeqRecDataset(args, mode="test", sample_num=args.sample_num)
# elif args.test_task.lower() == "itemsearch":
# test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num)
# elif args.test_task.lower() == "fusionseqrec":
# test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num)
# else:
# raise NotImplementedError
# return test_data
def load_json(file):
    """Read the JSON document stored at ``file`` and return the parsed data.

    The file is decoded as UTF-8 explicitly, so the result does not depend
    on the platform's default text encoding.

    Args:
        file: Path of the JSON file to read.

    Returns:
        The Python object produced by ``json.load``.
    """
    with open(file, 'r', encoding='utf-8') as f:
        return json.load(f)