import argparse
import json
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from peft import PeftModel
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer

from utils import *
from collator import TestCollator
from prompt import all_prompt
from evaluate import get_topk_results, get_metrics_results


def test_ddp(args):
    set_seed(args.seed)

    # torchrun sets WORLD_SIZE / LOCAL_RANK in the environment.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    torch.cuda.set_device(local_rank)
    if local_rank == 0:
        print(vars(args))

    # NOTE: using LOCAL_RANK as the global rank assumes single-node execution.
    dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)

    device_map = {"": local_rank}
    device = torch.device("cuda", local_rank)

    tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)

    args.lora = True  # force the LoRA branch regardless of the CLI flag
    if args.lora:
        # Load the frozen base model, then attach the LoRA adapter from ckpt_path.
        model = LlamaForCausalLM.from_pretrained(
            args.base_model,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )
        model.resize_token_embeddings(len(tokenizer))
        model = PeftModel.from_pretrained(
            model,
            args.ckpt_path,
            torch_dtype=torch.float16,
            device_map=device_map,
        )
    else:
        model = LlamaForCausalLM.from_pretrained(
            args.ckpt_path,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map=device_map,
        )
        # assert model.config.vocab_size == len(tokenizer)

    model = DistributedDataParallel(model, device_ids=[local_rank])

    if args.test_prompt_ids == "all":
        if args.test_task.lower() == "seqrec":
            prompt_ids = range(len(all_prompt["seqrec"]))
        elif args.test_task.lower() == "itemsearch":
            prompt_ids = range(len(all_prompt["itemsearch"]))
        elif args.test_task.lower() == "fusionseqrec":
            prompt_ids = range(len(all_prompt["fusionseqrec"]))
        else:
            raise NotImplementedError(f"Unknown test task: {args.test_task}")
    else:
        prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]

    test_data = load_test_dataset(args)
    # NOTE: drop_last=True discards the tail of the dataset that does not divide
    # evenly across ranks, so metrics are computed on a slightly smaller set.
    ddp_sampler = DistributedSampler(test_data, num_replicas=world_size,
                                     rank=local_rank, drop_last=True)
    collator = TestCollator(args, tokenizer)
    all_items = test_data.get_all_items()
    # Constrains beam search so only valid item identifiers can be generated.
    prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)

    test_loader = DataLoader(test_data, batch_size=args.test_batch_size,
                             collate_fn=collator, sampler=ddp_sampler,
                             num_workers=2, pin_memory=True)

    if local_rank == 0:
        print("data num:", len(test_data))

    model.eval()

    metrics = args.metrics.split(",")
    all_prompt_results = []
    with torch.no_grad():
        for prompt_id in prompt_ids:
            if local_rank == 0:
                print("Start prompt: ", prompt_id)

            test_loader.dataset.set_prompt(prompt_id)
            metrics_results = {}
            total = 0

            for step, batch in enumerate(tqdm(test_loader)):
                inputs = batch[0].to(device)
                targets = batch[1]
                bs = len(targets)
                num_beams = args.num_beams

                # Constrained beam search; on OOM, retry with one beam fewer.
                while True:
                    try:
                        output = model.module.generate(
                            input_ids=inputs["input_ids"],
                            attention_mask=inputs["attention_mask"],
                            max_new_tokens=10,
                            prefix_allowed_tokens_fn=prefix_allowed_tokens,
                            num_beams=num_beams,
                            num_return_sequences=num_beams,
                            output_scores=True,
                            return_dict_in_generate=True,
                            early_stopping=True,
                        )
                        break
                    except torch.cuda.OutOfMemoryError:
                        print("Out of memory!")
                        torch.cuda.empty_cache()
                        num_beams = num_beams - 1
                        print("Beam:", num_beams)
                    except Exception:
                        raise

                output_ids = output["sequences"]
                scores = output["sequences_scores"]
                output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

                topk_res = get_topk_results(output, scores, targets, num_beams,
                                            all_items=all_items if args.filter_items else None)

                # Gather per-rank batch sizes and top-k results on every rank.
                bs_gather_list = [None for _ in range(world_size)]
                dist.all_gather_object(object_list=bs_gather_list, obj=bs)
                total += sum(bs_gather_list)
                res_gather_list = [None for _ in range(world_size)]
                dist.all_gather_object(object_list=res_gather_list, obj=topk_res)

                if local_rank == 0:
                    all_device_topk_res = []
                    for ga_res in res_gather_list:
                        all_device_topk_res += ga_res
                    batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
                    for m, res in batch_metrics_res.items():
                        if m not in metrics_results:
                            metrics_results[m] = res
                        else:
                            metrics_results[m] += res

                    if (step + 1) % 50 == 0:
                        temp = {}
                        for m in metrics_results:
                            temp[m] = metrics_results[m] / total
                        print(temp)

                dist.barrier()

            if local_rank == 0:
                for m in metrics_results:
                    metrics_results[m] = metrics_results[m] / total
                all_prompt_results.append(metrics_results)
                print("======================================================")
                print("Prompt {} results: ".format(prompt_id), metrics_results)
                print("======================================================")
                print("")

            dist.barrier()

    dist.barrier()

    if local_rank == 0:
        # Aggregate mean/min/max over all evaluated prompts.
        mean_results = {}
        min_results = {}
        max_results = {}
        for m in metrics:
            all_res = [_[m] for _ in all_prompt_results]
            mean_results[m] = sum(all_res) / len(all_res)
            min_results[m] = min(all_res)
            max_results[m] = max(all_res)

        print("======================================================")
        print("Mean results: ", mean_results)
        print("Min results: ", min_results)
        print("Max results: ", max_results)
        print("======================================================")

        save_data = {
            "test_prompt_ids": args.test_prompt_ids,
            "mean_results": mean_results,
            "min_results": min_results,
            "max_results": max_results,
            "all_prompt_results": all_prompt_results,
        }
        with open(args.results_file, "w") as f:
            json.dump(save_data, f, indent=4)
        print("Save file: ", args.results_file)

        # Completion notification by email.
        # NOTE: credentials are hardcoded here; prefer reading them from
        # environment variables before sharing this script.
        import smtplib
        from email.mime.text import MIMEText

        mail_host = 'smtp.qq.com'
        mail_code = 'ouzplpngooqndjcb'
        sender = '1849334588@qq.com'
        receiver = 'esperanto1949@foxmail.com'
        task = '[v67: evaluate lcrec]'
        message = MIMEText('Task {task} Finished'.format(task=task), 'plain', 'utf-8')
        message['Subject'] = 'Auto Email'
        message['From'] = sender
        message['To'] = receiver
        server = smtplib.SMTP_SSL(mail_host, 465)
        server.login(sender, mail_code)
        server.sendmail(sender, receiver, message.as_string())
        server.quit()


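# Launch sketch (the script filename and argument values below are assumptions,
# not taken from the repo): WORLD_SIZE and LOCAL_RANK are read from the
# environment, so the intended entry point is torchrun, e.g.:
#
#   torchrun --nproc_per_node=4 test_ddp.py \
#       --base_model /path/to/llama --ckpt_path /path/to/lora_ckpt \
#       --test_task seqrec --test_prompt_ids all \
#       --results_file results.json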
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLMRec_test")
    parser = parse_global_args(parser)
    parser = parse_dataset_args(parser)
    parser = parse_test_args(parser)
    args = parser.parse_args()

    test_ddp(args)