| | import argparse
|
| | import json
|
| | import os
|
| | import sys
|
| |
|
| | import torch
|
| | import transformers
|
| | import torch.distributed as dist
|
| | from torch.utils.data.distributed import DistributedSampler
|
| | from torch.nn.parallel import DistributedDataParallel
|
| | from peft import PeftModel
|
| | from torch.utils.data import DataLoader
|
| | from tqdm import tqdm
|
| | from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
| |
|
| | from utils import *
|
| | from collator import VanillaCollator, TestCollator
|
| | from prompt import all_prompt
|
| | from evaluate import get_topk_results, get_metrics_results
|
| | from rq_llama import *
|
| |
|
# ---------------------------------------------------------------------------
# Command-line arguments, process-group bootstrap and model loading.
# ---------------------------------------------------------------------------
# Each parse_* helper receives the parser and hands one back, so they chain.
parser = parse_test_args(
    parse_dataset_args(
        parse_global_args(
            argparse.ArgumentParser(description="RQ-Llama Evaluation")
        )
    )
)
args = parser.parse_args()

set_seed(args.seed)

# WORLD_SIZE falls back to 1 when unset; an unset *or empty* LOCAL_RANK maps
# to 0 (hence the `or 0` instead of a .get() default).
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
torch.cuda.set_device(local_rank)

# Only the process with local rank 0 echoes the parsed configuration.
if local_rank == 0:
    print(vars(args))

# NOTE(review): the *local* rank is passed as the process-group rank, which
# only coincides with the global rank on a single node -- confirm that
# multi-node launches are out of scope for this script.
dist.init_process_group(
    backend="nccl",
    world_size=world_size,
    rank=local_rank,
)

# Place the whole model on this process's GPU.
device_map = {"": local_rank}
device = torch.device("cuda", local_rank)

# Load the checkpoint in fp16 and wrap it in DDP; the underlying model and
# tokenizer are reached through `rqllama.module` from here on.
rqllama = DistributedDataParallel(
    LlamaWithRQ.from_pretrained(
        args.ckpt_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map=device_map,
    ),
    device_ids=[local_rank],
)
|
| |
|
# Decide which prompt templates are evaluated.
#   * "all"  -> every prompt registered in `all_prompt` for the selected task
#   * else   -> a comma-separated list of integer prompt indices
if args.test_prompt_ids == "all":
    test_task_key = args.test_task.lower()
    # Only these tasks are expanded by "all". Fail fast with a clear message
    # instead of leaving `prompt_ids` undefined and crashing with a NameError
    # once the evaluation loop starts.
    if test_task_key not in ("seqrec", "itemsearch", "fusionseqrec"):
        raise ValueError(
            "Unsupported test_task {!r} for test_prompt_ids='all'; expected one "
            "of 'seqrec', 'itemsearch', 'fusionseqrec'.".format(args.test_task)
        )
    prompt_ids = range(len(all_prompt[test_task_key]))
else:
    prompt_ids = [int(tok) for tok in args.test_prompt_ids.split(",")]
|
| |
|
# ---------------------------------------------------------------------------
# Test data: dataset, per-process sampler, collator and loader.
# ---------------------------------------------------------------------------
test_data = load_test_dataset(args)
if local_rank == 0:
    print("data num:", len(test_data))

# One sampler replica per process; drop_last=True is forwarded to the sampler.
ddp_sampler = DistributedSampler(
    test_data,
    num_replicas=world_size,
    rank=local_rank,
    drop_last=True,
)

# The collator is built from the tokenizer carried by the wrapped model.
collator = TestCollator(args, rqllama.module.tokenizer)

# Item set, passed to get_topk_results() only when args.filter_items is set.
all_items = test_data.get_all_items()

# Callback later handed to generate() as `prefix_allowed_tokens_fn`.
prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(
    rqllama.module.tokenizer
)

test_loader = DataLoader(
    dataset=test_data,
    sampler=ddp_sampler,
    collate_fn=collator,
    batch_size=args.test_batch_size,
    num_workers=2,
    pin_memory=True,
)
|
| |
|
# ---------------------------------------------------------------------------
# Evaluation loop.
#
# For every prompt template, each process generates `num_beams` candidates per
# sample of its shard, the per-sample top-k results and batch sizes are
# gathered from all processes, and the process with local rank 0 accumulates
# the metric sums, which are divided by the number of evaluated samples.
# ---------------------------------------------------------------------------
rqllama.eval()

metrics = args.metrics.split(",")
all_prompt_results = []

with torch.no_grad():
    for prompt_id in prompt_ids:
        if local_rank == 0:
            print("Start prompt: ", prompt_id)

        test_loader.dataset.set_prompt(prompt_id)
        metrics_results = {}  # metric name -> running sum over samples
        total = 0             # samples evaluated so far, across all processes

        for step, batch in enumerate(tqdm(test_loader)):
            inputs = batch[0].to(device)
            targets = batch[1]
            bs = len(targets)
            num_beams = args.num_beams

            # Retry generation with one beam fewer each time the GPU runs out
            # of memory, giving up once no beam is left.
            while True:
                try:
                    output = rqllama.module.model.generate(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs["attention_mask"],
                        max_new_tokens=10,
                        prefix_allowed_tokens_fn=prefix_allowed_tokens,
                        num_beams=num_beams,
                        num_return_sequences=num_beams,
                        output_scores=True,
                        return_dict_in_generate=True,
                        early_stopping=True,
                    )
                    break
                except torch.cuda.OutOfMemoryError as oom_err:
                    print("Out of memory!")
                    num_beams -= 1
                    print("Beam:", num_beams)
                    if num_beams < 1:
                        # Previously the loop went on to call generate() with
                        # 0 beams and surfaced an unrelated, message-less
                        # RuntimeError.
                        raise RuntimeError(
                            "CUDA out of memory even with a single beam; "
                            "reduce the test batch size."
                        ) from oom_err
                except Exception as gen_err:
                    # Keep the RuntimeError type but preserve the real cause.
                    raise RuntimeError(
                        "Generation failed: {}".format(gen_err)
                    ) from gen_err

                # Only reached when an OOM retry is about to happen. Releasing
                # the cache here, after the handler has exited, lets the
                # memory held via the exception's traceback be returned too.
                torch.cuda.empty_cache()

            output_ids = output["sequences"]
            scores = output["sequences_scores"]

            output = rqllama.module.tokenizer.batch_decode(
                output_ids, skip_special_tokens=True
            )
            topk_res = get_topk_results(
                output,
                scores,
                targets,
                num_beams,
                all_items=all_items if args.filter_items else None,
            )

            # Share batch sizes and top-k results between all processes.
            bs_gather_list = [None for _ in range(world_size)]
            dist.all_gather_object(obj=bs, object_list=bs_gather_list)
            total += sum(bs_gather_list)
            res_gather_list = [None for _ in range(world_size)]
            dist.all_gather_object(obj=topk_res, object_list=res_gather_list)

            if local_rank == 0:
                all_device_topk_res = []
                for ga_res in res_gather_list:
                    all_device_topk_res += ga_res
                batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
                for m, res in batch_metrics_res.items():
                    if m not in metrics_results:
                        metrics_results[m] = res
                    else:
                        metrics_results[m] += res

                # Periodic running averages.
                if (step + 1) % 50 == 0:
                    temp = {m: metrics_results[m] / total for m in metrics_results}
                    print(temp)

            dist.barrier()

        if local_rank == 0:
            # Turn the accumulated sums into per-sample averages.
            for m in metrics_results:
                metrics_results[m] = metrics_results[m] / total

            all_prompt_results.append(metrics_results)
            print("======================================================")
            print("Prompt {} results: ".format(prompt_id), metrics_results)
            print("======================================================")
            print("")

        dist.barrier()

dist.barrier()
|
| |
|
# ---------------------------------------------------------------------------
# Summarise the per-prompt metrics (mean / min / max) and write them to disk.
# Runs only on the process with local rank 0, which holds the results.
# ---------------------------------------------------------------------------
if local_rank == 0:
    # metric name -> list of that metric's value for every evaluated prompt
    per_metric = {
        m: [prompt_res[m] for prompt_res in all_prompt_results] for m in metrics
    }
    mean_results = {m: sum(vals) / len(vals) for m, vals in per_metric.items()}
    min_results = {m: min(vals) for m, vals in per_metric.items()}
    max_results = {m: max(vals) for m, vals in per_metric.items()}

    separator = "======================================================"
    print(separator)
    print("Mean results: ", mean_results)
    print("Min results: ", min_results)
    print("Max results: ", max_results)
    print(separator)

    # Key order matches the order in which the report is written to JSON.
    save_data = {
        "test_prompt_ids": args.test_prompt_ids,
        "mean_results": mean_results,
        "min_results": min_results,
        "max_results": max_results,
        "all_prompt_results": all_prompt_results,
    }

    with open(args.results_file, "w") as f:
        json.dump(save_data, f, indent=4)
    print("Save file: ", args.results_file)