diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..caf4c186a99ef8bc906290e77f2411e4b8caac57 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +asset/model.jpg filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a2c9818803b994107c80c85192cadc2d52dcb86f --- /dev/null +++ b/README.md @@ -0,0 +1,124 @@ +# LC-Rec + +This is the official PyTorch implementation for the paper: + +> [Adapting Large Language Models by Integrating Collaborative Semantics for Recommendation](https://arxiv.org/abs/2311.09049) + +## Overview + +We propose **LC-Rec**, a new approach to integrate **L**anguage and **C**ollaborative semantics for improving LLMs in **Rec**ommender systems. To tackle the large gap between the language semantics modeled by LLMs and collaborative semantics implied by recommender systems, we make two major contributions in two aspects. For item indexing, we design a learning-based vector quantization method with uniform semantic mapping, which can assign meaningful and non-conflicting IDs (called item indices) for items. For alignment tuning, we propose a series of specially designed tuning tasks to enhance the integration of collaborative semantics in LLMs. Our fine-tuning tasks enforce LLMs to deeply integrate language and collaborative semantics (characterized by the learned item indices), so as to achieve an effective adaptation to recommender systems. 
+ +![model](./asset/model.jpg) + +## Requirements + +``` +torch==1.13.1+cu117 +accelerate +bitsandbytes +deepspeed +evaluate +peft +sentencepiece +tqdm +transformers +``` + +## Model Checkpoint + +The delta weights on the three datasets can be downloaded from huggingface hub ([Instruments](https://huggingface.co/bwzheng0324/lc-rec-instruments-delta), [Arts](https://huggingface.co/bwzheng0324/lc-rec-arts-delta), [Games](https://huggingface.co/bwzheng0324/lc-rec-games-delta)). After downloading, you can add our deltas to the original LLaMA weights to obtain LC-Rec weights: + +1. Get the original [LLaMA](https://huggingface.co/huggyllama/llama-7b) weights. +2. Use the following scripts to get LC-Rec weights by applying our delta. + +```shell +python -m convert/merge_delta.py \ + --base-model-path /path/to/llama-7b \ + --target-model-path /path/output/lc-rec \ + --delta-path bwzheng0324/lc-rec-games-delta +``` + +## Dataset + +We use three datasets in our paper, all of which have been uploaded to [Google Drive](https://drive.google.com/drive/folders/1RcJ2M1l5zWPHYuGd9l5Gibcs5w5aI3y6?usp=sharing) + +## Train + +The detailed scripts for all three datasets are in `run.sh`: + +```shell +DATASET=Games +BASE_MODEL=huggyllama/llama-7b +DATA_PATH=./data +OUTPUT_DIR=./ckpt/$DATASET/ + +torchrun --nproc_per_node=8 --master_port=3324 finetune.py \ + --base_model $BASE_MODEL \ + --output_dir $OUTPUT_DIR \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --per_device_batch_size 8 \ + --gradient_accumulation_steps 2 \ + --learning_rate 5e-5 \ + --epochs 4 \ + --weight_decay 0.01 \ + --save_and_eval_strategy epoch \ + --deepspeed ./config/ds_z3_bf16.json \ + --bf16 \ + --only_train_response \ + --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \ + --train_prompt_sample_num 1,1,1,1,1,1 \ + --train_data_sample_num 0,0,0,100000,0,0 \ + --index_file .index.json + + +cd convert +nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 & +cd .. 
+``` + +## Test + +Test with a single GPU: + +```shell +DATASET=Games +DATA_PATH=./data +OUTPUT_DIR=./ckpt/$DATASET/ +RESULTS_FILE=./results/$DATASET/result.json + +python test.py \ + --gpu_id 0 \ + --ckpt_path $CKPT_PATH \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --results_file $RESULTS_FILE \ + --test_batch_size 1 \ + --num_beams 20 \ + --test_prompt_ids all \ + --index_file .index.json +``` + +Test with multiple GPUs: + +```shell +DATASET=Games +DATA_PATH=./data +OUTPUT_DIR=./ckpt/$DATASET/ +RESULTS_FILE=./results/$DATASET/result.json + +torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \ + --ckpt_path $CKPT_PATH \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --results_file $RESULTS_FILE \ + --test_batch_size 1 \ + --num_beams 20 \ + --test_prompt_ids all \ + --index_file .index.json +``` + +## Acknowledgement + +The implementation is based on [HuggingFace](https://github.com/huggingface/transformers). + diff --git a/asset/model.jpg b/asset/model.jpg new file mode 100644 index 0000000000000000000000000000000000000000..68c3b8342a2a8af1ea6e9b4b17beee66b1a3fa81 --- /dev/null +++ b/asset/model.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:52223d0ef7f3701a6e40db9997e78c0a7f0d6bfce7965b9f27637e0e25fd1097 +size 1126259 diff --git a/collator.py b/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..ed42396d75edab69a1422950ba4a2e05aecc4e10 --- /dev/null +++ b/collator.py @@ -0,0 +1,273 @@ +import torch +import copy +import argparse +from dataclasses import dataclass + +import transformers +import math +from torch.utils.data import Sampler +import torch.distributed as dist +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration + +class VanillaCollator(object): + def __init__(self, args, tokenizer): + self.args = args + self.tokenizer = tokenizer + def __call__(self, data): + # print('collator data:',data) + ''' + [{ + 
'input_ids': + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n + ### Instruction:\n + Access the user's historical item interaction records: {inters}. + Your objective is to describe the next potential item for him, taking into account his past interactions.\n\n + ### Response:", + 'labels': + "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n + ### Instruction:\n + Access the user's historical item interaction records: {inters}. + Your objective is to describe the next potential item for him, taking into account his past interactions.\n\n + ### Response: + Dunlop guitar picks are a top choice of today's pro musician! Dunlop's wide variety of gauges, shapes, sizes and materials + allows the player to select the exact pick for his/her own particular style of playing. From classic country to nu-metal, + every great player knows that their pick is an integral part of their tone, and Dunlop guitar picks are the picks that more + pros rely on in the studio or on stage. Picks are a grossly underrated accessory. Don't sacrifice your tone...pick Dunlop guitar picks!.", + 'inters': '341,2804,3895,3893,7064', + 'item': 'placeholder', + 'task': 'inters2description' + }, + { + 'input_ids': + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n + ### Instruction:\n + Based on the user\'s historical interactions with the following items: {inters}. + You can infer his preference by observing the historical interactions: "The user\'s short-term preferences have shift to heavier picks, + suggesting that He is looking for a heavier sound.". Now the user wants a new item and searches for: "I like the durability and + effectiveness of the picks.". Please select a suitable item that matches his preference and search intent.\n\n + ### Response:', + 'labels': + 'Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n\n + ### Instruction:\n + Based on the user\'s historical interactions with the following items: {inters}. + You can infer his preference by observing the historical interactions: "The user\'s short-term preferences have shift to heavier picks, + suggesting that He is looking for a heavier sound.". Now the user wants a new item and searches for: "I like the durability and + effectiveness of the picks.". Please select a suitable item that matches his preference and search intent.\n\n + ### Response:{item}', + 'inters': '122,469,8918', + 'item': '7140', + 'task': 'itemsearch' + }] + ''' + dict_data = { + 'input_ids': [], + 'labels': [], + 'inters': [], + 'item': [], + 'users': [], + 'user': [], + 'task': [] + } + + for d in data: + for k in dict_data.keys(): + if k == 'labels': + dict_data[k].append(d[k] + self.tokenizer.eos_token) + else: + dict_data[k].append(d[k]) + + return dict_data + +class Collator(object): + + def __init__(self, args, tokenizer): + self.args = args + self.only_train_response = args.only_train_response + self.tokenizer = tokenizer + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = self.tokenizer.unk_token_id + # print(self.tokenizer.model_max_length) + + def __call__(self, batch): + + input_texts = [d["input_ids"] for d in batch] + full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch] + + inputs = self.tokenizer( + text = full_texts, + text_target = input_texts, + return_tensors="pt", + padding="longest", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_attention_mask=True, + ) + labels = copy.deepcopy(inputs["input_ids"]) + if self.only_train_response: + # ignore padding + labels[labels == self.tokenizer.pad_token_id] = -100 + # ignore input text + labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100 + + inputs["labels"] = labels + + return inputs + +class TestCollator(object): + def __init__(self, 
args, tokenizer): + self.args = args + self.tokenizer = tokenizer + if self.tokenizer.pad_token_id is None: + self.tokenizer.pad_token_id = 0 + + if isinstance(self.tokenizer, LlamaTokenizer): + self.tokenizer.padding_side = "left" + + def __call__(self, batch): + input_texts = [d["input_ids"] for d in batch] + targets = [d["labels"] for d in batch] + inputs = self.tokenizer( + text = input_texts, + return_tensors ="pt", + padding = "longest", + max_length = self.tokenizer.model_max_length, + truncation = True, + return_attention_mask = True, + ) + + return (inputs, targets) + +# RuntimeError: Cannot re-initialize CUDA in forked subprocess. +# To use CUDA with multiprocessing, you must use the 'spawn' start method. +# class ValidCollator(object): +# def __init__(self, args, model): +# self.args = args +# self.model = model +# self.only_train_response = args.only_train_response +# self.tokenizer = model.tokenizer +# def __call__(self, data): +# llama_model = self.model.model.get_decoder() +# for d in data: +# inter_emb_list = [] +# inter_item_list = d['inters'].split(',') +# for inter_item in inter_item_list: +# inter_feature = self.model.item_texts[inter_item]['title'] + ' ' + self.model.item_texts[inter_item]['description'] +# inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) +# inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) +# inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) +# inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) +# inter_emb_list.append(inter_emb.detach()) +# inter_embs = torch.cat(inter_emb_list, dim = 0) +# item_feature = self.model.item_texts[d['item']]['title'] + ' ' + self.model.item_texts[d['item']]['description'] +# item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) +# item_emb = 
llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) +# item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) +# item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) +# item_emb = item_emb.detach() + +# rqids = self.model.rqvae.get_indices(torch.cat([inter_embs, item_emb], dim = 0)) + +# inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1] +# item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1] + +# text_rqids = {} +# code = '' +# for rqid in inters_rqids: +# for k, idx in enumerate(rqid): +# code = code + self.model.prefix[k].format(idx) +# code = code + ', ' +# text_rqids['inters'] = code[:-2] +# code = '' +# for k, idx in enumerate(item_rqid): +# code = code + self.model.prefix[k].format(idx) +# text_rqids['item'] = code + +# d['input_ids'] = d['input_ids'].format(inters = text_rqids['inters']) +# d['labels'] = d['labels'].format(inters = text_rqids['inters'], item = text_rqids['item']) + +# input_texts = [d["input_ids"] for d in data] +# full_texts = [d["labels"] + self.tokenizer.eos_token for d in data] + +# inputs = self.tokenizer( +# text = full_texts, +# text_target = input_texts, +# return_tensors="pt", +# padding="longest", +# max_length=self.tokenizer.model_max_length, +# truncation=True, +# return_attention_mask=True, +# ) + +# labels = copy.deepcopy(inputs["input_ids"]) +# if self.only_train_response: +# labels[labels == self.tokenizer.pad_token_id] = -100 +# labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100 +# inputs["labels"] = labels + +# return inputs + +# RuntimeError: Cannot re-initialize CUDA in forked subprocess. +# To use CUDA with multiprocessing, you must use the 'spawn' start method. 
+# class TestCollator(object): +# def __init__(self, args, model): +# self.args = args +# self.model = model +# self.tokenizer = model.tokenizer +# if self.tokenizer.pad_token_id is None: +# self.tokenizer.pad_token_id = 0 +# if isinstance(self.tokenizer, LlamaTokenizer): +# self.tokenizer.padding_side = "left" + +# def __call__(self, data): +# llama_model = self.model.model.get_decoder() +# for d in data: +# inter_emb_list = [] +# inter_item_list = d['inters'].split(',') +# for inter_item in inter_item_list: +# inter_feature = self.model.item_texts[inter_item]['title'] + ' ' + self.model.item_texts[inter_item]['description'] +# inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) +# inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) +# inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) +# inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) +# inter_emb_list.append(inter_emb.detach()) +# inter_embs = torch.cat(inter_emb_list, dim = 0) +# item_feature = self.model.item_texts[d['item']]['title'] + ' ' + self.model.item_texts[d['item']]['description'] +# item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) +# item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) +# item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) +# item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) +# item_emb = item_emb.detach() + +# rqids = self.model.rqvae.get_indices(torch.cat([inter_embs, item_emb], dim = 0)) + +# inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1] +# item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1] + +# text_rqids = {} +# code = '' +# for rqid in inters_rqids: +# for k, idx in 
enumerate(rqid): +# code = code + self.model.prefix[k].format(idx) +# code = code + ', ' +# text_rqids['inters'] = code[:-2] +# code = '' +# for k, idx in enumerate(item_rqid): +# code = code + self.model.prefix[k].format(idx) +# text_rqids['item'] = code + +# d['input_ids'] = d['input_ids'].format(inters = text_rqids['inters']) +# d['labels'] = d['labels'].format(inters = text_rqids['inters'], item = text_rqids['item']) + +# input_texts = [d["input_ids"] for d in data] +# targets = [d["labels"] for d in data] + +# inputs = self.tokenizer( +# text=input_texts, +# return_tensors="pt", +# padding="longest", +# max_length=self.tokenizer.model_max_length, +# truncation=True, +# return_attention_mask=True, +# ) + +# return (inputs, targets) \ No newline at end of file diff --git a/config/ds_z2_bf16.json b/config/ds_z2_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..838e00e2b903ab73bd9fa2cc4ce436a84d944c23 --- /dev/null +++ b/config/ds_z2_bf16.json @@ -0,0 +1,28 @@ +{ + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at end of file diff --git a/config/ds_z2_fp16.json b/config/ds_z2_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..d15699e50636bf2536592312c37162c070d586df --- /dev/null +++ b/config/ds_z2_fp16.json @@ -0,0 +1,34 @@ +{ + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 
16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 2, + "allgather_partitions": true, + "allgather_bucket_size": 5e8, + "overlap_comm": true, + "reduce_scatter": true, + "reduce_bucket_size": 5e8, + "contiguous_gradients": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at end of file diff --git a/config/ds_z3_bf16.json b/config/ds_z3_bf16.json new file mode 100644 index 0000000000000000000000000000000000000000..0f4ca4a761aa64ec0d2953ce828c83704a1a77a4 --- /dev/null +++ b/config/ds_z3_bf16.json @@ -0,0 +1,31 @@ +{ + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": false + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at end of file diff --git a/config/ds_z3_bf16_save16bit.json b/config/ds_z3_bf16_save16bit.json new file mode 100644 index 0000000000000000000000000000000000000000..fc109cbd6b6531b5cb57cab2f121e1cfda645ba4 --- /dev/null +++ b/config/ds_z3_bf16_save16bit.json @@ -0,0 
+1,31 @@ +{ + "bf16": { + "enabled": "auto" + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at end of file diff --git a/config/ds_z3_fp16.json b/config/ds_z3_fp16.json new file mode 100644 index 0000000000000000000000000000000000000000..dbb387fccc9d72cdc3627226fb2a430e1d15418b --- /dev/null +++ b/config/ds_z3_fp16.json @@ -0,0 +1,37 @@ +{ + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": false + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at 
end of file diff --git a/config/ds_z3_fp16_save16bit.json b/config/ds_z3_fp16_save16bit.json new file mode 100644 index 0000000000000000000000000000000000000000..00a09986aeeaec23f81ad14c00cebe660d330b62 --- /dev/null +++ b/config/ds_z3_fp16_save16bit.json @@ -0,0 +1,37 @@ +{ + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 16, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "zero_optimization": { + "stage": 3, + "overlap_comm": true, + "contiguous_gradients": true, + "sub_group_size": 1e9, + "reduce_bucket_size": "auto", + "stage3_prefetch_bucket_size": "auto", + "stage3_param_persistence_threshold": "auto", + "stage3_max_live_parameters": 1e9, + "stage3_max_reuse_distance": 1e9, + "stage3_gather_16bit_weights_on_model_save": true + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "steps_per_print": 2000, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false, + "flops_profiler": { + "enabled": true, + "profile_step": 10, + "module_depth": -1, + "top_modules": 3, + "detailed": true, + "output_file": "flops_profiler.out" + } +} \ No newline at end of file diff --git a/continue_pretrain.py b/continue_pretrain.py new file mode 100644 index 0000000000000000000000000000000000000000..18781b7379f7ee03a8a765f693b0da655c7cfa02 --- /dev/null +++ b/continue_pretrain.py @@ -0,0 +1,125 @@ +import os +import sys +from typing import List +import argparse + +import wandb +import torch +import transformers +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from peft import ( + TaskType, + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) + +from collator import VanillaCollator +from rq_llama import * +from utils import * + +parser = argparse.ArgumentParser(description = 'rqllama-pretrain') +parser = parse_global_args(parser) +parser = parse_train_args(parser) 
+parser = parse_dataset_args(parser) +parser = parse_rqvae_args(parser) +args = parser.parse_args() +wandb.init(config = args, reinit = True) + +set_seed(args.seed) +ensure_dir(args.output_dir) + +device_map = "auto" +world_size = int(os.environ.get("WORLD_SIZE", 1)) +ddp = world_size != 1 +local_rank = int(os.environ.get("LOCAL_RANK") or 0) +if local_rank == 0: + print(vars(args)) +if ddp: + device_map = {"": local_rank} + +train_data, valid_data = load_datasets(args) + +rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map) +for i in range(len(args.num_emb_list)): + rqllama.item_rqvae.rq.vq_layers[i].initted = True + rqllama.user_rqvae.rq.vq_layers[i].initted = True + +if local_rank == 0: + print("token num:", len(rqllama.tokenizer)) + print("data num:", len(train_data)) + rqllama.tokenizer.save_pretrained(args.output_dir) + rqllama.config.save_pretrained(args.output_dir) + +if args.resume_from_checkpoint: + checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin") + args.resume_from_checkpoint = False + if os.path.exists(checkpoint_name): + if local_rank == 0: + print(f"Restarting from {checkpoint_name}") + adapters_weights = torch.load(checkpoint_name) + rqllama.model = set_peft_model_state_dict(rqllama.model, adapters_weights) + else: + if local_rank == 0: + print(f"Checkpoint {checkpoint_name} not found") + +if local_rank == 0: + rqllama.model.print_trainable_parameters() + +if not ddp and torch.cuda.device_count() > 1: + rqllama.is_parallelizable = True + rqllama.model_parallel = True + +collator = VanillaCollator(args, rqllama.tokenizer) + +trainer = transformers.Trainer( + model = rqllama, + train_dataset = train_data, + eval_dataset = valid_data, + args = transformers.TrainingArguments( + seed = args.seed, + per_device_train_batch_size = args.per_device_batch_size, + per_device_eval_batch_size = args.per_device_batch_size, + gradient_accumulation_steps = 
args.gradient_accumulation_steps, + warmup_ratio = args.warmup_ratio, + num_train_epochs = args.epochs, + learning_rate = args.learning_rate, + weight_decay = args.weight_decay, + lr_scheduler_type = args.lr_scheduler_type, + fp16 = args.fp16, + bf16 = args.bf16, + logging_steps = args.logging_step, + optim = args.optim, + gradient_checkpointing = True, + evaluation_strategy = args.save_and_eval_strategy, + save_strategy = args.save_and_eval_strategy, + eval_steps = args.save_and_eval_steps, + save_steps = args.save_and_eval_steps, + output_dir = args.output_dir, + save_total_limit = 5, + load_best_model_at_end = True, + deepspeed = args.deepspeed, + ddp_find_unused_parameters = False if ddp else None, + report_to = None, + eval_delay = 1 if args.save_and_eval_strategy=="epoch" else 2000, + dataloader_num_workers = args.dataloader_num_workers, + dataloader_prefetch_factor = args.dataloader_prefetch_factor, + remove_unused_columns = args.remove_unused_columns, + ), + tokenizer = rqllama.tokenizer, + data_collator = collator, +) +rqllama.config.use_cache = False + +if torch.__version__ >= "2" and sys.platform != "win32": + rqllama = torch.compile(rqllama) + +trainer.train(resume_from_checkpoint = args.resume_from_checkpoint) + +trainer.save_state() +trainer.save_model(output_dir = args.output_dir) + +if local_rank == 0: + print('rqllama pre-train finished.') \ No newline at end of file diff --git a/convert/convert.py b/convert/convert.py new file mode 100644 index 0000000000000000000000000000000000000000..f23f257f0c26116db8a0480d84f9d0ac5207f068 --- /dev/null +++ b/convert/convert.py @@ -0,0 +1,16 @@ +import transformers +import argparse +import os + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--source", "-s", type=str, default="", help="source path of models") + parser.add_argument("--target", "-t", type=str, default="", help="target path of models") + + args, _ = parser.parse_known_args() + + assert 
os.path.exists(args.source) + assert args.target != "" + + model = transformers.AutoModelForCausalLM.from_pretrained(args.source) + model.save_pretrained(args.target, state_dict=model.state_dict()) \ No newline at end of file diff --git a/convert/convert.sh b/convert/convert.sh new file mode 100644 index 0000000000000000000000000000000000000000..da33cc755b1a61b8d811119fe969f93c97535d4d --- /dev/null +++ b/convert/convert.sh @@ -0,0 +1,18 @@ +model=$1 + +set -x + +for step in `ls ${model} | grep checkpoint | awk -F'-' '{ print $2 }'` +do +mkdir ${model}/tmp-checkpoint-${step} +mkdir ${model}/final-checkpoint-${step} +python ./zero_to_fp32.py ${model}/checkpoint-${step}/ ${model}/tmp-checkpoint-${step}/pytorch_model.bin +cp ${model}/*.json ${model}/tmp-checkpoint-${step} +python ./convert.py -s ${model}/tmp-checkpoint-${step} -t ${model}/final-checkpoint-${step} +cp ${model}/checkpoint-${step}/*.json ${model}/final-checkpoint-${step} +cp ${model}/*.json ${model}/final-checkpoint-${step} +cp ${model}/tokenizer* ${model}/final-checkpoint-${step} +cp ${model}/train* ${model}/final-checkpoint-${step} +#rm -rf ${model}/tmp-checkpoint-${step} ${model}/checkpoint-${step} ${model}/global_step${step} +#mv ${model}/final-checkpoint-${step} ${model}/checkpoint-${step} +done \ No newline at end of file diff --git a/convert/convert_fp16.py b/convert/convert_fp16.py new file mode 100644 index 0000000000000000000000000000000000000000..e23216d65b1b937d1c823e030df0ab569aec989a --- /dev/null +++ b/convert/convert_fp16.py @@ -0,0 +1,23 @@ + +import argparse + +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch + + +def convert_fp16(in_checkpoint, out_checkpoint): + tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False) + model = AutoModelForCausalLM.from_pretrained( + in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + model.save_pretrained(out_checkpoint) + tokenizer.save_pretrained(out_checkpoint) + + +if __name__ == 
"__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--in-checkpoint", type=str, help="Path to the model") + parser.add_argument("--out-checkpoint", type=str, help="Path to the output model") + args = parser.parse_args() + + convert_fp16(args.in_checkpoint, args.out_checkpoint) diff --git a/convert/make_delta.py b/convert/make_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..69fcf88b267e60fc208f97964324a4e41b39d239 --- /dev/null +++ b/convert/make_delta.py @@ -0,0 +1,46 @@ + +import argparse + +import torch +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM + + +def make_delta(base_model_path, target_model_path, delta_path): + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the target model from {target_model_path}") + target = AutoModelForCausalLM.from_pretrained( + target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False) + + print("Calculating the delta") + for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"): + assert name in base.state_dict() + if param.shape == base.state_dict()[name].shape: + param.data -= base.state_dict()[name] + else: + print(name) + + print(f"Saving the delta to {delta_path}") + if args.hub_repo_id: + kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id} + else: + kwargs = {} + target.save_pretrained(delta_path, **kwargs) + target_tokenizer.save_pretrained(delta_path, **kwargs) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument("--hub-repo-id", 
type=str) + args = parser.parse_args() + + make_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/convert/merge_delta.py b/convert/merge_delta.py new file mode 100644 index 0000000000000000000000000000000000000000..dabbe3a3c77ee75e0ec382caa38c64289febef0e --- /dev/null +++ b/convert/merge_delta.py @@ -0,0 +1,167 @@ + +import argparse +import gc +import glob +import json +import os +import shutil +import tempfile + +from huggingface_hub import snapshot_download +import torch +from torch import nn +from tqdm import tqdm +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig + + +GB = 1 << 30 + + +def split_files(model_path, tmp_path, split_size): + if not os.path.exists(model_path): + model_path = snapshot_download(repo_id=model_path) + if not os.path.exists(tmp_path): + os.makedirs(tmp_path) + + file_pattern = os.path.join(model_path, "pytorch_model-*.bin") + files = glob.glob(file_pattern) + + part = 0 + try: + for file_path in tqdm(files): + state_dict = torch.load(file_path) + new_state_dict = {} + + current_size = 0 + for name, param in state_dict.items(): + param_size = param.numel() * param.element_size() + + if current_size + param_size > split_size: + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + current_size = 0 + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + + new_state_dict[name] = param + current_size += param_size + + new_file_name = f"pytorch_model-{part}.bin" + new_file_path = os.path.join(tmp_path, new_file_name) + torch.save(new_state_dict, new_file_path) + new_state_dict = None + gc.collect() + new_state_dict = {} + part += 1 + except Exception as e: + print(f"An error occurred during split_files: {e}") + shutil.rmtree(tmp_path) + raise + + +def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path): + delta_tokenizer = 
AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta_config = AutoConfig.from_pretrained(delta_path) + + if os.path.exists(target_model_path): + shutil.rmtree(target_model_path) + os.makedirs(target_model_path) + + split_size = 4 * GB + + with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path: + print(f"Split files for the base model to {tmp_base_path}") + split_files(base_model_path, tmp_base_path, split_size) + print(f"Split files for the delta weights to {tmp_delta_path}") + split_files(delta_path, tmp_delta_path, split_size) + + base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin") + base_files = glob.glob(base_pattern) + base_state_dict = torch.load(base_files[0]) + delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin") + delta_files = glob.glob(delta_pattern) + # delta_state_dict = torch.load(delta_files[0]) + + print("Applying the delta") + weight_map = {} + total_size = 0 + + for i, delta_file in tqdm(enumerate(delta_files)): + state_dict = torch.load(delta_file) + file_name = f"pytorch_model-{i}.bin" + for name, param in state_dict.items(): + if name not in base_state_dict: + for base_file in base_files: + base_state_dict = torch.load(base_file) + gc.collect() + if name in base_state_dict: + break + if state_dict[name].shape == base_state_dict[name].shape: + state_dict[name] += base_state_dict[name] + else: + print(name) + weight_map[name] = file_name + total_size += param.numel() * param.element_size() + gc.collect() + torch.save(state_dict, os.path.join(target_model_path, file_name)) + + with open( + os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w" + ) as f: + json.dump( + {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f + ) + + print(f"Saving the target model to {target_model_path}") + delta_tokenizer.save_pretrained(target_model_path) + delta_config.save_pretrained(target_model_path) + + +def apply_delta(base_model_path, 
target_model_path, delta_path): + print(f"Loading the delta weights from {delta_path}") + delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False) + delta = AutoModelForCausalLM.from_pretrained( + delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print(f"Loading the base model from {base_model_path}") + base = AutoModelForCausalLM.from_pretrained( + base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True + ) + + print("Applying the delta") + for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"): + assert name in base.state_dict() + if param.shape == base.state_dict()[name].shape: + param.data += base.state_dict()[name] + else: + print(name) + + + print(f"Saving the target model to {target_model_path}") + delta.save_pretrained(target_model_path) + delta_tokenizer.save_pretrained(target_model_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--base-model-path", type=str, required=True) + parser.add_argument("--target-model-path", type=str, required=True) + parser.add_argument("--delta-path", type=str, required=True) + parser.add_argument( + "--low-cpu-mem", + action="store_true", + help="Lower the cpu memory usage. This will split large files and use " + "disk as swap to reduce the memory usage below 10GB.", + ) + args = parser.parse_args() + + if args.low_cpu_mem: + apply_delta_low_cpu_mem( + args.base_model_path, args.target_model_path, args.delta_path + ) + else: + apply_delta(args.base_model_path, args.target_model_path, args.delta_path) diff --git a/convert/zero_to_fp32.py b/convert/zero_to_fp32.py new file mode 100644 index 0000000000000000000000000000000000000000..12f6efbbd4e8530ef997c08d95a6a4460039c3ce --- /dev/null +++ b/convert/zero_to_fp32.py @@ -0,0 +1,600 @@ +#!/usr/bin/env python + +# Copyright (c) Microsoft Corporation. 
+# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets +# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in +# the future. Once extracted, the weights don't require DeepSpeed and can be used in any +# application. +# +# example: python zero_to_fp32.py . pytorch_model.bin + +import argparse +import torch +import glob +import math +import os +import re +from collections import OrderedDict +from dataclasses import dataclass +from tqdm import tqdm + +# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with +# DeepSpeed data structures it has to be available in the current python environment. +from deepspeed.utils import logger +from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS, + FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES, + FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS) + + +@dataclass +class zero_model_state: + buffers: dict() + param_shapes: dict() + shared_params: list + ds_version: int + frozen_param_shapes: dict() + frozen_param_fragments: dict() + + +debug = 0 + +# load to cpu +device = torch.device('cpu') + + +def atoi(text): + return int(text) if text.isdigit() else text + + +def natural_keys(text): + ''' + alist.sort(key=natural_keys) sorts in human order + http://nedbatchelder.com/blog/200712/human_sorting.html + (See Toothy's implementation in the comments) + ''' + return [atoi(c) for c in re.split(r'(\d+)', text)] + + +def get_model_state_file(checkpoint_dir, zero_stage): + if not os.path.isdir(checkpoint_dir): + raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist") + + # there should be only one file + if zero_stage == 2: + file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt") + elif zero_stage == 3: + file = os.path.join(checkpoint_dir, 
"zero_pp_rank_0_mp_rank_00_model_states.pt") + + if not os.path.exists(file): + raise FileNotFoundError(f"can't find model states file at '{file}'") + + return file + + +def get_checkpoint_files(checkpoint_dir, glob_pattern): + # XXX: need to test that this simple glob rule works for multi-node setup too + ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys) + + if len(ckpt_files) == 0: + raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'") + + return ckpt_files + + +def get_optim_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt") + + +def get_model_state_files(checkpoint_dir): + return get_checkpoint_files(checkpoint_dir, "*_model_states.pt") + + +def parse_model_states(files): + zero_model_states = [] + for file in files: + state_dict = torch.load(file, map_location=device) + + if BUFFER_NAMES not in state_dict: + raise ValueError(f"{file} is not a model state checkpoint") + buffer_names = state_dict[BUFFER_NAMES] + if debug: + print("Found buffers:", buffer_names) + + # recover just the buffers while restoring them to fp32 if they were saved in fp16 + buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names} + param_shapes = state_dict[PARAM_SHAPES] + + # collect parameters that are included in param_shapes + param_names = [] + for s in param_shapes: + for name in s.keys(): + param_names.append(name) + + # update with frozen parameters + frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None) + if frozen_param_shapes is not None: + if debug: + print(f"Found frozen_param_shapes: {frozen_param_shapes}") + param_names += list(frozen_param_shapes.keys()) + + # record shared parameters so that they can be recovered based on partners + # this is because such parameters holding reference only are not saved by optimizer + shared_params = [] + for param in state_dict["module"]: + if param not in [*param_names, 
*buffer_names]: + for share_param in state_dict["module"]: + if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr() + and share_param != param): + shared_params.append([param, share_param]) + break + + ds_version = state_dict.get(DS_VERSION, None) + + frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None) + + z_model_state = zero_model_state(buffers=buffers, + param_shapes=param_shapes, + shared_params=shared_params, + ds_version=ds_version, + frozen_param_shapes=frozen_param_shapes, + frozen_param_fragments=frozen_param_fragments) + zero_model_states.append(z_model_state) + + return zero_model_states + + +def parse_optim_states(files, ds_checkpoint_dir): + + total_files = len(files) + state_dicts = [] + for i, f in enumerate(tqdm(files)): + state_dicts.append(torch.load(f, map_location=device)) + if i == 0: + if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]: + raise ValueError(f"{files[0]} is not a zero checkpoint") + zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE] + world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT] + + # For ZeRO-2 each param group can have different partition_count as data parallelism for expert + # parameters can be different from data parallelism for non-expert parameters. So we can just + # use the max of the partition_count to get the dp world_size. + + if type(world_size) is list: + world_size = max(world_size) + + if world_size != total_files: + raise ValueError( + f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. " + "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes." 
+ ) + + # the groups are named differently in each stage + if zero_stage == 2: + fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS + elif zero_stage == 3: + fp32_groups_key = FP32_FLAT_GROUPS + else: + raise ValueError(f"unknown zero stage {zero_stage}") + + key_list = list(state_dicts[-1][OPTIMIZER_STATE_DICT].keys()) + for key in key_list: + if zero_stage == 2: + if key != fp32_groups_key: + del state_dicts[-1][OPTIMIZER_STATE_DICT][key] + elif zero_stage == 3: + if key == fp32_groups_key: + value = torch.cat(state_dicts[-1][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) + del state_dicts[-1][OPTIMIZER_STATE_DICT][key] + if key == fp32_groups_key: + state_dicts[-1][OPTIMIZER_STATE_DICT][key] = value + + print('zero_stage:', zero_stage) + fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + # if zero_stage == 2: + # # fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))] + # elif zero_stage == 3: + # # if there is more than one param group, there will be multiple flattened tensors - one + # # flattened tensor per group - for simplicity merge them into a single tensor + # # + # # XXX: could make the script more memory efficient for when there are multiple groups - it + # # will require matching the sub-lists of param_shapes for each param group flattened tensor + + # print('start!') + # # fp32_flat_groups = [ + # # torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts)) + # # ] + + return zero_stage, world_size, fp32_flat_groups + + +def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir): + """ + Returns fp32 state_dict reconstructed from ds checkpoint + + Args: + - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are) + + """ + print(f"Processing zero checkpoint '{ds_checkpoint_dir}'") + + optim_files = get_optim_files(ds_checkpoint_dir) + zero_stage, world_size, 
fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir) + print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}") + + model_files = get_model_state_files(ds_checkpoint_dir) + + zero_model_states = parse_model_states(model_files) + print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}') + + if zero_stage == 2: + return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states) + elif zero_stage == 3: + return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states) + + +def _zero2_merge_frozen_params(state_dict, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + frozen_param_fragments = zero_model_states[0].frozen_param_fragments + + if debug: + num_elem = sum(s.numel() for s in frozen_param_shapes.values()) + print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in frozen_param_fragments.values()]) + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in frozen_param_shapes.items(): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + state_dict[name] = frozen_param_fragments[name] + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + + 
# Reconstruction protocol: + # + # XXX: document this + + if debug: + for i in range(world_size): + for j in range(len(fp32_flat_groups[0])): + print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}") + + # XXX: memory usage doubles here (zero2) + num_param_groups = len(fp32_flat_groups[0]) + merged_single_partition_of_fp32_groups = [] + for i in range(num_param_groups): + merged_partitions = [sd[i] for sd in fp32_flat_groups] + full_single_fp32_vector = torch.cat(merged_partitions, 0) + merged_single_partition_of_fp32_groups.append(full_single_fp32_vector) + avail_numel = sum( + [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups]) + + if debug: + wanted_params = sum([len(shapes) for shapes in param_shapes]) + wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes]) + # not asserting if there is a mismatch due to possible padding + print(f"Have {avail_numel} numels to process.") + print(f"Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + total_numel = 0 + total_params = 0 + for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups): + offset = 0 + avail_numel = full_single_fp32_vector.numel() + for name, shape in shapes.items(): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + if debug: + print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ") + state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape) + offset += unpartitioned_numel + + # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and + # avail_numel can differ by anywhere between 0..2*world_size. 
Due to two unrelated complex + # paddings performed in the code it's almost impossible to predict the exact numbers w/o the + # live optimizer object, so we are checking that the numbers are within the right range + align_to = 2 * world_size + + def zero2_align(x): + return align_to * math.ceil(x / align_to) + + if debug: + print(f"original offset={offset}, avail_numel={avail_numel}") + + offset = zero2_align(offset) + avail_numel = zero2_align(avail_numel) + + if debug: + print(f"aligned offset={offset}, avail_numel={avail_numel}") + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + _zero2_merge_frozen_params(state_dict, zero_model_states) + + _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + state_dict[pair[0]] = state_dict[pair[1]] + + return state_dict + + +def zero3_partitioned_param_info(unpartitioned_numel, world_size): + remainder = unpartitioned_numel % world_size + padding_numel = (world_size - remainder) if remainder else 0 + partitioned_numel = math.ceil(unpartitioned_numel / world_size) + return partitioned_numel, padding_numel + + +def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states): + if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0: + return + + if debug: + for i in range(world_size): + num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values()) + print(f'rank {i}: 
{FROZEN_PARAM_SHAPES}.numel = {num_elem}') + + frozen_param_shapes = zero_model_states[0].frozen_param_shapes + wanted_params = len(frozen_param_shapes) + wanted_numel = sum(s.numel() for s in frozen_param_shapes.values()) + avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size + print(f'Frozen params: Have {avail_numel} numels to process.') + print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params') + + total_params = 0 + total_numel = 0 + for name, shape in tqdm(zero_model_states[0].frozen_param_shapes.items()): + total_params += 1 + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + + param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states) + state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape) + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements") + + +def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states): + param_shapes = zero_model_states[0].param_shapes + avail_numel = fp32_flat_groups[0].numel() * world_size + # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each + # param, re-consolidating each param, while dealing with padding if any + + # merge list of dicts, preserving order + param_shapes = {k: v for d in param_shapes for k, v in d.items()} + + if debug: + for i in range(world_size): + print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}") + + wanted_params = len(param_shapes) + wanted_numel = sum(shape.numel() for shape in param_shapes.values()) + # 
not asserting if there is a mismatch due to possible padding + avail_numel = fp32_flat_groups[0].numel() * world_size + print(f"Trainable params: Have {avail_numel} numels to process.") + print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.") + + # params + # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support + # out-of-core computing solution + offset = 0 + total_numel = 0 + total_params = 0 + for name, shape in tqdm(param_shapes.items()): + + unpartitioned_numel = shape.numel() + total_numel += unpartitioned_numel + total_params += 1 + + partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size) + + if debug: + print( + f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}" + ) + + # XXX: memory usage doubles here + state_dict[name] = torch.cat( + tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)), + 0).narrow(0, 0, unpartitioned_numel).view(shape) + offset += partitioned_numel + + offset *= world_size + + # Sanity check + if offset != avail_numel: + raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong") + + print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements") + + +def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states): + state_dict = OrderedDict() + + # buffers + buffers = zero_model_states[0].buffers + state_dict.update(buffers) + if debug: + print(f"added {len(buffers)} buffers") + + _zero3_merge_frozen_params(state_dict, world_size, zero_model_states) + + _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states) + + # recover shared parameters + for pair in zero_model_states[0].shared_params: + state_dict[pair[0]] = state_dict[pair[1]] + + return 
state_dict + + +def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with + ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example + via a model hub. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14`` + + Returns: + - pytorch ``state_dict`` + + Note: this approach may not work if your application doesn't have sufficient free CPU memory and + you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with + the checkpoint. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint + # do the training and checkpoint saving + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu + model = model.cpu() # move to cpu + model.load_state_dict(state_dict) + # submit to model hub or save the model to share with others + + In this example the ``model`` will no longer be usable in the deepspeed context of the same + application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead. 
+ + """ + if tag is None: + latest_path = os.path.join(checkpoint_dir, 'latest') + if os.path.isfile(latest_path): + with open(latest_path, 'r') as fd: + tag = fd.read().strip() + else: + raise ValueError(f"Unable to find 'latest' file at {latest_path}") + + ds_checkpoint_dir = os.path.join(checkpoint_dir, tag) + + if not os.path.isdir(ds_checkpoint_dir): + raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist") + + return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir) + + +def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None): + """ + Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be + loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed. + + Args: + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + """ + + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + print(f"Saving fp32 state dict to {output_file}") + torch.save(state_dict, output_file) + + +def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None): + """ + 1. Put the provided model to cpu + 2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` + 3. Load it into the provided model + + Args: + - ``model``: the model object to update + - ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``) + - ``tag``: checkpoint tag used as a unique identifier for checkpoint. 
If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14`` + + Returns: + - ``model`: modified model + + Make sure you have plenty of CPU memory available before you call this function. If you don't + have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it + conveniently placed for you in the checkpoint folder. + + A typical usage might be :: + + from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint + model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir) + # submit to model hub or save the model to share with others + + Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context + of the same application. i.e. you will need to re-initialize the deepspeed engine, since + ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it. + + """ + logger.info(f"Extracting fp32 weights") + state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag) + + logger.info(f"Overwriting model with fp32 weights") + model = model.cpu() + model.load_state_dict(state_dict, strict=False) + + return model + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint_dir", + type=str, + help="path to the desired checkpoint folder, e.g., path/checkpoint-12") + parser.add_argument( + "output_file", + type=str, + help="path to the pytorch fp32 state_dict output file (e.g. 
path/checkpoint-12/pytorch_model.bin)") + parser.add_argument("-d", "--debug", action='store_true', help="enable debug") + args = parser.parse_args() + + debug = args.debug + + convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file) diff --git a/data.py b/data.py new file mode 100644 index 0000000000000000000000000000000000000000..2047784fcc86c46f58b0226908e46b1e5c25cf3b --- /dev/null +++ b/data.py @@ -0,0 +1,1035 @@ +import copy +import random +import argparse +import os +import torch +import torch.nn as nn +from torch.utils.data import Dataset +from tqdm import tqdm +from collections import defaultdict +import torch.distributed as dist +import logging +import re +import pdb +import json +from prompt import sft_prompt, all_prompt +import numpy as np + +class BaseDataset(Dataset): + def __init__(self, args): + super().__init__() + + self.args = args + self.dataset = args.dataset + self.data_path = os.path.join(args.data_path, self.dataset) + + self.max_his_len = args.max_his_len + self.his_sep = args.his_sep + self.index_file = args.index_file + self.user_index_file = args.user_index_file + self.add_prefix = args.add_prefix + + self.new_tokens = None + self.allowed_tokens = None + self.all_items = None + + def _load_data(self): + with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f: + self.indices = json.load(f) + + def get_new_tokens(self): + if self.new_tokens is not None: + return self.new_tokens + + self.new_tokens = set() + for index in self.indices.values(): + for token in index: + self.new_tokens.add(token) + self.new_tokens = sorted(list(self.new_tokens)) + + return self.new_tokens + + def get_all_items(self): + if self.all_items is not None: + return self.all_items + + self.all_items = set() + for index in self.indices.values(): + self.all_items.add("".join(index)) + + return self.all_items + + def get_prefix_allowed_tokens_fn(self, tokenizer): + if self.allowed_tokens is None: + self.allowed_tokens = 
{} + for index in self.indices.values(): + for i, token in enumerate(index): + token_id = tokenizer(token)["input_ids"][1] + if i not in self.allowed_tokens.keys(): + self.allowed_tokens[i] = set() + self.allowed_tokens[i].add(token_id) + self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id]) + sep = tokenizer("Response:")["input_ids"][1:] + + def prefix_allowed_tokens_fn(batch_id, sentence): + sentence = sentence.tolist() + reversed_sent = sentence[::-1] + for i in range(len(reversed_sent)): + if reversed_sent[i:i + len(sep)] == sep[::-1]: + return list(self.allowed_tokens[i]) + + return prefix_allowed_tokens_fn + + def _process_data(self): + raise NotImplementedError + +class UserFeatDataset(BaseDataset): + def __init__(self, args, task = "pref2user", prompt_sample_num = 1, sample_num = -1): + super().__init__(args) + + self.task = task.lower() + self.prompt_sample_num = prompt_sample_num + self.sample_num = sample_num + + self.prompts = all_prompt[self.task] + + self._load_data() + self.feat_data = self._process_data() + + def _load_data(self): + with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f: + user_feat = json.load(f) + # >>> user_feat.keys() + # dict_keys(['user_explicit_preference', 'user_vague_intention']) + self.user_feat = user_feat['user_explicit_preference'] + # >>> user_feat['0'] + # ['The user is a passionate musician who enjoys exploring different types of musical instruments.'] + # >>> len(user_feat) + # 24772 + + def _process_data(self): + feat_data = [] + for uid in self.user_feat: + one_data = {} + one_data['user'] = uid + + preference = " ".join(self.user_feat[uid]) + preference = preference.strip().strip(".!?,;:`") + preference = preference.replace('{','').replace('}','') + one_data['preference'] = preference + + feat_data.append(one_data) + + if self.sample_num > 0: + all_idx = range(len(feat_data)) + sample_idx = np.random.choice(all_idx, self.sample_num, replace = False) + 
feat_data = np.array(feat_data)[sample_idx].tolist() + + return feat_data + + def __len__(self): + return len(self.feat_data) * self.prompt_sample_num + + def __getitem__(self, index): + idx = index // self.prompt_sample_num + d = self.feat_data[idx] + prompt_id = random.randint(0, len(self.prompts) - 1) + prompt = self.prompts[prompt_id] + + if self.task == 'pref2user': + instruction = prompt['instruction'].format(preference = d['preference']) + input = sft_prompt.format(instruction = instruction, response = "") + output = sft_prompt.format(instruction = instruction, response = prompt["response"]) + return dict( + input_ids = input, + labels = output, + inters = 'placeholder', + item = 'placeholder', + users = 'placeholder', + user = d['user'], + task = self.task + ) + elif self.task == 'user2pref': + input = sft_prompt.format(instruction = prompt["instruction"], response = "") + response = prompt["response"].format(preference = d['preference']) + output = sft_prompt.format(instruction = prompt["instruction"], response = response) + return dict( + input_ids = input, + labels = output, + inters = 'placeholder', + item = 'placeholder', + users = 'placeholder', + user = d['user'], + task = self.task + ) + else: + raise NotImplementedError + +class UserSearchDataset(BaseDataset): + def __init__(self, args, prompt_sample_num = 1, prompt_id = 0, sample_num = -1): + super().__init__(args) + + self.prompt_sample_num = prompt_sample_num + self.prompt_id = prompt_id + self.sample_num = sample_num + + self.prompts = all_prompt["usersearch"] + + self._load_data() + self.search_data = self._process_data() + + def _load_data(self): + with open(os.path.join(self.data_path, self.dataset + ".inter.user.json"), 'r') as f: + self.user_inters = json.load(f) + + def _process_data(self): + search_data = [] + for iid in self.user_inters.keys(): + users = self.user_inters[iid] + for i in range(1, len(users)): + one_data = {} + one_data['item'] = iid + one_data['user'] = str(users[i]) + 
                history = users[:i]

                # NOTE(review): unlike other datasets this truncates without
                # checking max_his_len > 0 — confirm max_his_len is always
                # positive for this task.
                if len(history) > self.max_his_len:
                    history = history[-self.max_his_len:]

                # comma-joined user-id string
                one_data['users'] = ''
                for user in history:
                    one_data['users'] = one_data['users'] + str(user) + ','
                one_data['users'] = one_data['users'][:-1]

                search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def __len__(self):
        return len(self.search_data) * self.prompt_sample_num

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.search_data[idx]

        prompt_id = random.randint(0, len(self.prompts) - 1)
        prompt = self.prompts[prompt_id]

        # Raw templates are returned; {users}/{item}/{user} placeholders are
        # presumably substituted downstream from the extra fields — verify
        # against the collator.
        input = sft_prompt.format(instruction = prompt["instruction"], response = "")
        output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
        return dict(
            input_ids = input,
            labels = output,
            inters = 'placeholder',
            item = d['item'],
            users = d['users'],
            user = d['user'],
            task = 'usersearch'
        )

# =====================================================================================================================
# seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item
# =====================================================================================================================

# seqrec
class SeqRecDataset(BaseDataset):
    # Sequential recommendation: given a user's interaction history, predict
    # the next item. Supports train / valid / test modes; the last two items
    # of each sequence are held out for valid / test respectively.
    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["seqrec"]

        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
            # self.inter_data = self.inter_data[:10]
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            # self.inter_data = self.inter_data[:10]
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
            # self.inter_data = self.inter_data[:10]
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)

    def _process_train_data(self):
        # Every prefix (excluding the two held-out tail items) predicts the
        # next item; histories are comma-joined id strings.
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                one_data['user'] = uid
                one_data['item'] = str(items[i])
                history = items[:i]
                if self.max_his_len > 0:
                    history = history[-self.max_his_len:]
                one_data['inters'] = ''
                for item in history:
                    one_data['inters'] = one_data['inters'] + str(item) + ','
                one_data['inters'] = one_data['inters'][:-1]

                inter_data.append(one_data)

        return inter_data

    def _process_valid_data(self):
        # Second-to-last item of each sequence is the validation target.
        inter_data = []
        for uid in self.inters:
            one_data = dict()
            items = self.inters[uid]
            one_data['user'] = uid
            one_data['item'] = str(items[-2])
            history = items[:-2]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            one_data['inters'] = ''
            for item in history:
                one_data['inters'] = one_data['inters'] + str(item) + ','
            one_data['inters'] = one_data['inters'][:-1]
            inter_data.append(one_data)

        return inter_data

    def _process_test_data(self):
        # Test mode works directly in index-token space: items and users are
        # remapped to their concatenated index strings here.
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

        with open(self.user_index_file, 'r') as f:
            self.user_indices = json.load(f)
        self.remapped_users = dict()
        for uid in self.inters:
            new_user= ''.join(self.user_indices[uid])
            self.remapped_users[uid] = new_user

        inter_data = []
        for uid in self.remapped_inters:
            one_data = dict()
            one_data['user'] = self.remapped_users[uid]
            items = self.remapped_inters[uid]
            # last item of the full sequence is the test target
            one_data['item'] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        # for uid in self.inters:
        #     one_data = dict()
        #     items = self.inters[uid]
        #     one_data["item"] = str(items[-1])
        #     history = items[:-1]
        #     if self.max_his_len > 0:
        #         history = history[-self.max_his_len:]
        #     one_data['inters'] = ''
        #     for item in history:
        #         one_data['inters'] = one_data['inters'] + str(item) + ','
        #     one_data['inters'] = one_data['inters'][:-1]
        #     inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace = False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        # Pre-materialize (input, label) text pairs for validation, either
        # sampling prompt_sample_num prompts per datum or using one fixed one.
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    prompt = self.prompts[prompt_id]
                    input = sft_prompt.format(instruction = prompt["instruction"], response = "")
                    output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
                    self.valid_text_data.append({
                        "input_ids": input,
                        "labels": output,
                        "inters": d['inters'],
                        "item": d['item'],
                        "users": 'placeholder',
                        "user": d['user'],
                        "task": 'seqrec'})
        else:
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                input = sft_prompt.format(instruction = prompt["instruction"], response = "")
                output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
                self.valid_text_data.append({
                    "input_ids": input, "labels": output,
                    "inters": d['inters'], "item": d['item'], "users": 'placeholder', "user": d['user'],
                    "task": 'seqrec'})

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction = instruction, response = "")
        output = sft_prompt.format(instruction = instruction, response = response)

        # Test mode needs the bare response as the generation target.
        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            # Test mode substitutes everything now and returns early with the
            # bare response as the label.
            prompt_id = self.prompt_id
            prompt = self.prompts[prompt_id]
            instruction = prompt["instruction"].format(**d)
            response = prompt["response"].format(**d)
            input = sft_prompt.format(instruction = instruction, response = "")
            return dict(input_ids = input, labels = response)
            # output = prompt["response"]
            # return dict(input_ids = input, labels = output, inters = d['inters'], item = d['item'], task = 'seqrec')

        prompt = self.prompts[prompt_id]

        # Train mode returns the raw templates plus the fields; placeholders
        # are presumably filled downstream (verify against the collator).
        input = sft_prompt.format(instruction = prompt["instruction"], response = "")
        output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])

        return dict(input_ids = input, labels = output, inters = d['inters'], item = d['item'], user = d['user'], task = 'seqrec', users = 'placeholder')

# itemsearch & query2item
class ItemSearchDataset(BaseDataset):
    """Search-style tasks built from user preferences and vague intentions.

    itemsearch : user history + preference + query -> target item
    query2item : query -> target item
    """

    def __init__(self, args, mode="train", task = 'itemsearch',
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.task = task.lower()
        self.prompts = all_prompt[self.task]

        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)

    def _process_data(self):
        # One sample per user: explicit preferences plus the (user-related,
        # item-related) vague-intention queries targeting one item.
        search_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        for uid in user_explicit_preference.keys():
            one_data = {}
            one_data['user'] = uid
            user_ep = user_explicit_preference[uid]
            user_vi = user_vague_intention[uid]["querys"]
            one_data["explicit_preferences"] = user_ep
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = user_vague_intention[uid]["item"]
            inters = user_vague_intention[uid]["inters"]

            # Users without interaction history cannot form a search sample.
            if len(inters) == 0:
                continue

            one_data["item"] = str(iid)

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            one_data["inters"] = ''
            for item in inters:
                one_data["inters"] = one_data["inters"] + str(item) + ','
            one_data["inters"] = one_data["inters"][:-1]

            search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        elif self.mode == 'test':
            return len(self.search_data)
        else:
            return len(self.search_data)

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction = instruction, response = "")
        output = sft_prompt.format(instruction = instruction, response = response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num

        d = self.search_data[idx]
        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        # Strip braces so the free text survives later str.format passes.
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        d["explicit_preference"] = d["explicit_preference"].replace('{','').replace('}','')
        d["user_related_intention"] = d["user_related_intention"].replace('{','').replace('}','')
        d["item_related_intention"] = d["item_related_intention"].replace('{','').replace('}','')
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        # d["query"] = d["query"].replace('{','').replace('}','')

        if self.task == 'itemsearch':
            # Keep {inters}/{user} as literal placeholders for downstream fill.
            sub_d = d.copy()
            sub_d.pop('inters')
            sub_d.pop('user')
            instruction = prompt["instruction"].format(inters='{inters}', user='{user}', **sub_d)
            input = sft_prompt.format(instruction = instruction, response = "")
            output = sft_prompt.format(instruction = instruction, response = prompt["response"])
            return dict(input_ids = input, labels = output, inters = d['inters'], item = d['item'], user = d['user'], task = self.task, users = 'placeholder')
        elif self.task == 'query2item':
            sub_d = d.copy()
            sub_d.pop('user')
            instruction = prompt["instruction"].format(user='{user}', **sub_d)
            input = sft_prompt.format(instruction = instruction, response = "")
            output = sft_prompt.format(instruction = instruction, response = prompt["response"])
            return dict(input_ids = input, labels = output, inters = 'placeholder', item = d['item'], user = d['user'], task = self.task, users = 'placeholder')

# inters2title & inters2description & intertitles2item
class FusionSeqRecDataset(BaseDataset):
    """Tasks fusing interaction sequences with item text.

    inters2title       : history (index strings) -> next item's title
    inters2description : history (index strings) -> next item's description
    intertitles2item   : history (titles)        -> next item's index
    """

    def __init__(self, args, mode="train", task = 'inters2title',
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.task = task.lower()
        self.prompts = all_prompt[self.task]

        # load data
        self._load_data()

        # build mode-specific samples
        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_train_data(self):
        # Every prefix (excluding the two held-out tail items) predicts the
        # next item's text; titles/descriptions are brace-stripped so they
        # survive later str.format passes.
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                one_data = dict()
                one_data["item"] = str(items[i])
                one_data['user'] = uid
                one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
                one_data["title"] = one_data["title"].replace('{','').replace('}','')
                one_data["description"] = self.item_feat[str(items[i])]["description"]
                one_data["description"] = one_data["description"].replace('{','').replace('}','')

                history = items[:i]
                if self.max_his_len > 0:
                    history = history[-self.max_his_len:]

                one_data['inters'] = ''
                for item in history:
                    one_data['inters'] = one_data['inters'] + str(item) +','
                one_data['inters'] = one_data['inters'][:-1]

                inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`").replace('{','').replace('}','') + "\"" for j in history]
                one_data["inter_titles"] = self.his_sep.join(inter_titles)

                inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def _process_valid_data(self):
        # Second-to-last item is the validation target.
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            one_data = dict()
            one_data["item"] = str(items[-2])
            # BUGFIX: __getitem__ returns d['user']; it was never set for the
            # valid/test splits (train sets it), causing a KeyError.
            one_data['user'] = uid
            one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
            # brace-strip title like the train split (was missing here)
            one_data["title"] = one_data["title"].replace('{','').replace('}','')
            one_data["description"] = self.item_feat[str(items[-2])]["description"]
            one_data["description"] = one_data["description"].replace('{','').replace('}','')

            history = items[:-2]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            one_data['inters'] = ''
            for item in history:
                one_data['inters'] = one_data['inters'] + str(item) +','
            one_data['inters'] = one_data['inters'][:-1]

            # brace-strip interaction titles like the train split
            inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`").replace('{','').replace('}','') + "\"" for j in history]
            one_data["inter_titles"] = self.his_sep.join(inter_titles)

            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def _process_test_data(self):
        # Last item of each sequence is the test target.
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            one_data = dict()
            one_data["item"] = str(items[-1])
            # BUGFIX: __getitem__'s test path reads d['user'] — set it here.
            one_data['user'] = uid
            one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
            # brace-strip like the train split (was missing for title and
            # description here)
            one_data["title"] = one_data["title"].replace('{','').replace('}','')
            one_data["description"] = self.item_feat[str(items[-1])]["description"]
            one_data["description"] = one_data["description"].replace('{','').replace('}','')

            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]

            one_data['inters'] = ''
            for item in history:
                one_data['inters'] = one_data['inters'] + str(item) +','
            one_data['inters'] = one_data['inters'][:-1]

            inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`").replace('{','').replace('}','') + "\"" for j in history]
            one_data["inter_titles"] = self.his_sep.join(inter_titles)

            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        # Pre-materialize validation (input, label) pairs per task.
        self.valid_text_data = []
        if self.sample_valid:
            all_prompt_ids = range(len(self.prompts))
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                if self.task == 'inters2title':
                    for prompt_id in prompt_ids:
                        prompt = self.prompts[prompt_id]
                        input = sft_prompt.format(instruction = prompt['instruction'], response = "")
                        response = prompt['response'].format(title = d['title'])
                        output = sft_prompt.format(instruction = prompt['instruction'], response = response)
                        self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
                elif self.task == 'inters2description':
                    for prompt_id in prompt_ids:
                        prompt = self.prompts[prompt_id]
                        input = sft_prompt.format(instruction = prompt['instruction'], response = "")
                        # BUGFIX: the description response template takes
                        # {description} (see __getitem__); this used to pass
                        # title=..., which raises KeyError at format time.
                        response = prompt['response'].format(description = d['description'])
                        output = sft_prompt.format(instruction = prompt['instruction'], response = response)
                        self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
                elif self.task == 'intertitles2item':
                    for prompt_id in prompt_ids:
                        prompt = self.prompts[prompt_id]
                        instruction = prompt['instruction'].format(inter_titles = d['inter_titles'])
                        input = sft_prompt.format(instruction = instruction, response = "")
                        output = sft_prompt.format(instruction = instruction, response = prompt["response"])
                        self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': 'placeholder', 'item': d['item'], 'task': self.task})
                else:
                    raise NotImplementedError
        else:
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for i in range(len(self.inter_data)):
                d = self.inter_data[i]
                if self.task == 'inters2title':
                    input = sft_prompt.format(instruction = prompt['instruction'], response = "")
                    response = prompt['response'].format(title = d['title'])
                    output = sft_prompt.format(instruction = prompt['instruction'], response = response)
                    self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
                elif self.task == 'inters2description':
                    input = sft_prompt.format(instruction = prompt['instruction'], response = "")
                    # BUGFIX: same title= -> description= fix as above.
                    response = prompt['response'].format(description = d['description'])
                    output = sft_prompt.format(instruction = prompt['instruction'], response = response)
                    self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
                elif self.task == 'intertitles2item':
                    instruction = prompt['instruction'].format(inter_titles = d['inter_titles'])
                    input = sft_prompt.format(instruction = instruction, response = "")
                    output = sft_prompt.format(instruction = instruction, response = prompt["response"])
                    self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': 'placeholder', 'item': d['item'], 'task': self.task})
                else:
                    raise NotImplementedError

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        # Test mode needs the bare response as the generation target.
        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id

        prompt = self.prompts[prompt_id]

        # NOTE(review): every branch returns d['user'] — confirm the valid/test
        # processing paths populate 'user' in each datum.
        if self.task == 'inters2title':
            input = sft_prompt.format(instruction = prompt['instruction'], response = "")
            response = prompt['response'].format(title = d['title'])
            output = sft_prompt.format(instruction = prompt['instruction'], response = response)
            return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = self.task, users = 'placeholder')
        elif self.task == 'inters2description':
            input = sft_prompt.format(instruction = prompt['instruction'], response = "")
            response = prompt['response'].format(description = d['description'])
            output = sft_prompt.format(instruction = prompt['instruction'], response = response)
            return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = self.task, users = 'placeholder')
        elif self.task == 'intertitles2item':
            # {user} stays a literal placeholder for downstream substitution.
            instruction = prompt['instruction'].format(user = '{user}', inter_titles = d['inter_titles'])
            input = sft_prompt.format(instruction = instruction, response = "")
            output = sft_prompt.format(instruction = instruction, response = prompt["response"])
            return dict(input_ids = input, labels = output, inters = 'placeholder', user = d['user'], item = d['item'], task = self.task, users = 'placeholder')
        else:
            raise NotImplementedError

# preferenceobtain
class PreferenceObtainDataset(BaseDataset):
    # Interaction history -> the user's explicit textual preference.
    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt["preferenceobtain"]

        # load data
        self._load_data()

        self.preference_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)

    def _process_data(self):
        preference_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]

        for uid in user_explicit_preference.keys():
            one_data = {}
            one_data['user'] = uid
            # drop the last three interactions (held out elsewhere)
            inters = self.inters[uid][:-3]
            user_ep = user_explicit_preference[uid]

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            one_data['inters'] = ''
            for item in inters:
                one_data['inters'] = one_data['inters'] + str(item) + ','
            one_data['inters'] = one_data['inters'][:-1]

            one_data["explicit_preferences"] = user_ep

            preference_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(preference_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            preference_data = np.array(preference_data)[sample_idx].tolist()

        return preference_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        return len(self.preference_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):

        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction = instruction, response = "")
        output = sft_prompt.format(instruction = instruction, response = response)

        return input, output

    def __getitem__(self, index):

        idx = index // self.prompt_sample_num

        d = self.preference_data[idx]
        prompt_id = random.randint(0, len(self.prompts) - 1)

        prompt = self.prompts[prompt_id]

        # Pick one preference sentence; brace-strip so str.format is safe.
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        d["explicit_preference"] = d["explicit_preference"].replace('{','').replace('}','')

        input = sft_prompt.format(instruction = prompt["instruction"], response = "")
        response = prompt["response"].format(**d)
        output = sft_prompt.format(instruction = prompt["instruction"], response = response)
        return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = 'preferenceobtain', users = 'placeholder')

# item2index & index2item
class ItemFeatDataset(BaseDataset):
    # Bidirectional mapping between item text features and item indices.
    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num
        self.sample_num = sample_num

        self.prompts = all_prompt[self.task]

        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        # Normalize title/description in place (trailing punctuation and
        # braces removed so str.format on templates is safe).
        feat_data = []
        for iid in self.item_feat:
            feat = self.item_feat[iid]
            feat["item"] = iid
            feat["title"] = feat["title"].strip().strip(".!?,;:`")
            feat["title"] = feat["title"].replace('{','').replace('}','')
            feat["description"] = feat["description"].strip().strip(".!?,;:`")
            feat["description"] = feat["description"].replace('{','').replace('}','')
            feat_data.append(feat)

        if self.sample_num > 0:
            all_idx = range(len(feat_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            feat_data = np.array(feat_data)[sample_idx].tolist()

        return feat_data

    def __len__(self):
        return len(self.feat_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction = instruction, response = "")
        output = sft_prompt.format(instruction = instruction, response = response)

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num
        d = self.feat_data[idx]

        prompt_id = random.randint(0, len(self.prompts) - 1)

        prompt = self.prompts[prompt_id]

        if self.task == 'item2index':
            # Features go into the instruction; the index placeholder in the
            # response is presumably filled downstream from 'item'.
            instruction = prompt["instruction"].format(**d)
            input = sft_prompt.format(instruction = instruction, response = "")
            output = sft_prompt.format(instruction = instruction, response = prompt["response"])
            return dict(input_ids = input, labels = output, inters = 'placeholder', user = 'placeholder', item = d['item'], task = self.task, users = 'placeholder')
        elif self.task == 'index2item':
            input = sft_prompt.format(instruction = prompt["instruction"], response = "")
            response = prompt["response"].format(**d)
            output = sft_prompt.format(instruction = prompt["instruction"], response = response)
            return dict(input_ids = input, labels = output, inters = 'placeholder', user = 'placeholder', item = d['item'], task = self.task, users = 'placeholder')
        else:
            raise NotImplementedError



class SeqRecTestDataset(BaseDataset):
    # Test-only sequential recommendation: histories are pre-mapped to index
    # strings and fully formatted prompts are returned.

    def __init__(self, args, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        # load data
        self._load_data()
        self._remap_items()

        self.inter_data = self._process_test_data()

    def _load_data(self):

        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)


    def _remap_items(self):

        # item id -> concatenated index-token string
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _process_test_data(self):

        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            one_data = dict()
            # one_data["user"] = uid
            one_data["item"] = items[-1]
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                # "1. <idx>" style numbering for each history entry
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            one_data["inters"] = self.his_sep.join(history)
            inter_data.append(one_data)

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)

            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):

        return len(self.inter_data)

    def _get_text_data(self, data, prompt):

        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")

        return input, response

    def __getitem__(self, index):

        d = self.inter_data[index]
        input, target = self._get_text_data(d, self.prompt)

        return dict(input_ids=input, labels=target)
\ No newline at end of file
diff --git a/data_finetune.py b/data_finetune.py
new file mode 100644
index 0000000000000000000000000000000000000000..3325c49681fca1c1a7270c050add177a92c9f2fc
--- /dev/null
+++ b/data_finetune.py
@@ -0,0 +1,1026 @@
import copy
import random
import argparse
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from tqdm import tqdm
from collections import defaultdict
import torch.distributed as dist
import logging
import re
import pdb
import json
from prompt_finetune import sft_prompt, all_prompt
import numpy as np

class BaseDataset(Dataset):
    # Shared plumbing for the finetune datasets: paths, history settings,
    # index files, and constrained-decoding helpers.

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len
        self.his_sep = args.his_sep
        self.index_file = args.index_file
        self.user_index_file = args.user_index_file
        self.add_prefix = args.add_prefix

        # lazily built caches
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):

        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        # Sorted list of all distinct index tokens (for tokenizer extension).

        if self.new_tokens is not None:
            return self.new_tokens

        self.new_tokens = set()
        for index in self.indices.values():
            for token in index:
                self.new_tokens.add(token)
        self.new_tokens = sorted(list(self.new_tokens))

        return self.new_tokens

    def get_all_items(self):
        # Set of all items as concatenated index strings.

        if self.all_items is not None:
            return self.all_items

        self.all_items = set()
        for index in self.indices.values():
            self.all_items.add("".join(index))

        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        # Build a prefix_allowed_tokens_fn for constrained generation: after
        # the "Response:" marker only valid index tokens (then EOS) may follow.

        if self.allowed_tokens is None:
            self.allowed_tokens = {}
            for index in self.indices.values():
                for i, token in enumerate(index):
                    # [1] skips the BOS token the tokenizer prepends
                    token_id = tokenizer(token)["input_ids"][1]
                    if i not in self.allowed_tokens.keys():
                        self.allowed_tokens[i] = set()
                    self.allowed_tokens[i].add(token_id)
            self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
        sep = tokenizer("Response:")["input_ids"][1:]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            sentence = sentence.tolist()
            reversed_sent = sentence[::-1]
            for i in range(len(reversed_sent)):
                if reversed_sent[i:i + len(sep)] == sep[::-1]:
                    # print(list(self.allowed_tokens[i]))
                    return list(self.allowed_tokens[i])

        return prefix_allowed_tokens_fn

    def _process_data(self):

        # Subclasses build their task-specific sample list here.
        raise NotImplementedError

class UserSearchFinetune(BaseDataset):
    # usersearch in index space: given an item's user-index history, predict
    # the next user index.
    def __init__(self, args, prompt_sample_num = 1, prompt_id = 0, sample_num = -1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num
        self.prompt_id = prompt_id
        self.sample_num = sample_num

        self.prompts = all_prompt["usersearch"]

        self._load_data()
        self._remap_items()
        self.search_data = self._process_data()

    def _load_data(self):
        with open(os.path.join(self.data_path, self.dataset + ".inter.user.json"), 'r') as f:
            self.user_inters = json.load(f)
        with open(self.user_index_file, 'r') as f:
            self.user_indices = json.load(f)
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        # user ids -> concatenated user-index strings, per item
        self.remapped_user_inters = dict()
        for iid, users in self.user_inters.items():
            new_users = ["".join(self.user_indices[str(i)]) for i in users]
            self.remapped_user_inters[iid] = new_users

    def _process_data(self):
        # Leave-one-out expansion over each item's user list.
        search_data = []
        for iid in self.remapped_user_inters.keys():
            users = self.remapped_user_inters[iid]
            for i in range(1, len(users)):
                one_data = {}
                # NOTE(review): self.indices[iid] is a list of index tokens
                # while users are joined into strings — confirm 'item' should
                # not be "".join(self.indices[iid]) before it is formatted
                # into prompt text.
                one_data['item'] = self.indices[iid]
                one_data['user'] = users[i]
                history = users[:i]

                # NOTE(review): truncates without checking max_his_len > 0 —
                # confirm max_his_len is always positive here.
                if len(history) > self.max_his_len:
                    history = history[-self.max_his_len:]

                one_data['users'] = self.his_sep.join(history)

                # one_data['users'] = ''
                # for user in history:
                #     one_data['users'] = one_data['users'] + str(user) + ','
                # one_data['users'] = one_data['users'][:-1]

                search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
            search_data = np.array(search_data)[sample_idx].tolist()
+ return search_data + + def __len__(self): + return len(self.search_data) * self.prompt_sample_num + + def __getitem__(self, index): + idx = index // self.prompt_sample_num + d = self.search_data[idx] + + prompt_id = random.randint(0, len(self.prompts) - 1) + prompt = self.prompts[prompt_id] + + instruction = prompt["instruction"].format(**d) + response = prompt["response"].format(**d) + + input = sft_prompt.format(instruction = prompt["instruction"], response = "") + output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"]) + + return dict(input_ids = input, labels = output) + # return dict( + # input_ids = input, + # labels = output, + # inters = 'placeholder', + # item = d['item'], + # users = d['users'], + # user = d['user'], + # task = 'usersearch' + # ) + +class UserFeatFinetune(BaseDataset): + def __init__(self, args, task = "pref2user", prompt_sample_num = 1, sample_num = -1): + super().__init__(args) + + self.task = task.lower() + self.prompt_sample_num = prompt_sample_num + self.sample_num = sample_num + + self.prompts = all_prompt[self.task] + + self._load_data() + self.feat_data = self._process_data() + + def _load_data(self): + with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f: + user_feat = json.load(f) + self.user_feat = user_feat['user_explicit_preference'] + with open(self.user_index_file, 'r') as f: + self.user_indices = json.load(f) + + def _process_data(self): + feat_data = [] + for uid in self.user_feat: + one_data = {} + one_data['user'] = self.user_indices[uid] + + preference = " ".join(self.user_feat[uid]) + preference = preference.strip().strip(".!?,;:`") + preference = preference.replace('{','').replace('}','') + one_data['preference'] = preference + + feat_data.append(one_data) + + if self.sample_num > 0: + all_idx = range(len(feat_data)) + sample_idx = np.random.choice(all_idx, self.sample_num, replace = False) + feat_data = np.array(feat_data)[sample_idx].tolist() + + 
return feat_data + + def __len__(self): + return len(self.feat_data) * self.prompt_sample_num + + def __getitem__(self, index): + idx = index // self.prompt_sample_num + d = self.feat_data[idx] + prompt_id = random.randint(0, len(self.prompts) - 1) + prompt = self.prompts[prompt_id] + + instruction = prompt["instruction"].format(**d) + response = prompt["response"].format(**d) + + input = sft_prompt.format(instruction = prompt["instruction"], response = "") + output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"]) + + return dict(input_ids = input, labels = output) + + # if self.task == 'pref2user': + # instruction = prompt['instruction'].format(preference = d['preference']) + # input = sft_prompt.format(instruction = instruction, response = "") + # output = sft_prompt.format(instruction = instruction, response = prompt["response"]) + # return dict( + # input_ids = input, + # labels = output, + # inters = 'placeholder', + # item = 'placeholder', + # users = 'placeholder', + # user = d['user'], + # task = self.task + # ) + # elif self.task == 'user2pref': + # input = sft_prompt.format(instruction = prompt["instruction"], response = "") + # response = prompt["response"].format(preference = d['preference']) + # output = sft_prompt.format(instruction = prompt["instruction"], response = response) + # return dict( + # input_ids = input, + # labels = output, + # inters = 'placeholder', + # item = 'placeholder', + # users = 'placeholder', + # user = d['user'], + # task = self.task + # ) + # else: + # raise NotImplementedError + +class SeqRecFinetune(BaseDataset): + + def __init__(self, args, mode="train", + prompt_sample_num=1, prompt_id=0, sample_num=-1): + super().__init__(args) + + self.mode = mode + self.prompt_sample_num = prompt_sample_num + self.prompt_id = prompt_id + self.sample_num = sample_num + + self.prompts = all_prompt["seqrec"] + + # load data + self._load_data() + self._remap_items() + + # load data + if self.mode == 
'train': + self.inter_data = self._process_train_data() + elif self.mode == 'valid': + self.sample_valid = args.sample_valid + self.valid_prompt_id = args.valid_prompt_id + self.inter_data = self._process_valid_data() + self._construct_valid_text() + elif self.mode == 'test': + self.inter_data = self._process_test_data() + else: + raise NotImplementedError + + def _load_data(self): + + with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f: + self.inters = json.load(f) + with open(self.index_file, 'r') as f: + self.indices = json.load(f) + with open(self.user_index_file, 'r') as f: + self.user_indices = json.load(f) + + def _remap_items(self): + + self.remapped_inters = dict() + for uid, items in self.inters.items(): + new_items = ["".join(self.indices[str(i)]) for i in items] + self.remapped_inters[uid] = new_items + + def _process_train_data(self): + + inter_data = [] + for uid in self.remapped_inters: + items = self.remapped_inters[uid][:-2] + for i in range(1, len(items)): + one_data = dict() + one_data["user"] = self.user_indices[uid] + one_data["item"] = items[i] + history = items[:i] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + if self.add_prefix: + history = [str(k+1) + ". " + item_idx for k, item_idx in enumerate(history)] + one_data["inters"] = self.his_sep.join(history) + inter_data.append(one_data) + + return inter_data + + def _process_valid_data(self): + + inter_data = [] + for uid in self.remapped_inters: + items = self.remapped_inters[uid] + one_data = dict() + # one_data["user"] = uid + one_data["user"] = self.user_indices[uid] + one_data["item"] = items[-2] + history = items[:-2] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + if self.add_prefix: + history = [str(k + 1) + ". 
" + item_idx for k, item_idx in enumerate(history)] + one_data["inters"] = self.his_sep.join(history) + inter_data.append(one_data) + + return inter_data + + def _process_test_data(self): + + inter_data = [] + for uid in self.remapped_inters: + items = self.remapped_inters[uid] + one_data = dict() + # one_data["user"] = uid + one_data["user"] = self.user_indices[uid] + one_data["item"] = items[-1] + history = items[:-1] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + if self.add_prefix: + history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)] + one_data["inters"] = self.his_sep.join(history) + inter_data.append(one_data) + + if self.sample_num > 0: + all_inter_idx = range(len(inter_data)) + sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False) + inter_data = np.array(inter_data)[sample_idx].tolist() + + return inter_data + + def set_prompt(self, prompt_id): + + self.prompt_id = prompt_id + + def __len__(self): + if self.mode == 'train': + return len(self.inter_data) * self.prompt_sample_num + elif self.mode == 'valid': + return len(self.valid_text_data) + elif self.mode == 'test': + return len(self.inter_data) + else: + raise NotImplementedError + + def _construct_valid_text(self): + self.valid_text_data = [] + if self.sample_valid: + all_prompt_ids = range(len(self.prompts)) + for i in range(len(self.inter_data)): + d = self.inter_data[i] + prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False) + for prompt_id in prompt_ids: + prompt = self.prompts[prompt_id] + input, output = self._get_text_data(d, prompt) + self.valid_text_data.append({"input_ids": input, "labels": output}) + else: + self.prompt_sample_num = 1 + prompt = self.prompts[self.valid_prompt_id] + for i in range(len(self.inter_data)): + d = self.inter_data[i] + input, output = self._get_text_data(d, prompt) + self.valid_text_data.append({"input_ids": input, "labels": output}) + + def _get_text_data(self, 
data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction = instruction, response = "") + output = sft_prompt.format(instruction = instruction, response = response) + + if self.mode == 'test': + return input, response + + return input, output + + def __getitem__(self, index): + + if self.mode == 'valid': + return self.valid_text_data[index] + + idx = index // self.prompt_sample_num + d = self.inter_data[idx] + # print(index, idx) + + if self.mode == 'train': + prompt_id = random.randint(0, len(self.prompts) - 1) + elif self.mode == 'test': + prompt_id = self.prompt_id + + prompt = self.prompts[prompt_id] + + input, output = self._get_text_data(d, prompt) + + # print({"input": input, "output": output}) + + return dict(input_ids=input, labels=output) + + +class FusionSeqRecFinetune(BaseDataset): + + def __init__(self, args, mode="train", + prompt_sample_num=1, prompt_id=0, sample_num=-1): + super().__init__(args) + + self.mode = mode + self.prompt_sample_num = prompt_sample_num + self.prompt_id = prompt_id + self.sample_num = sample_num + + self.prompts = all_prompt["fusionseqrec"] + + # load data + self._load_data() + # self._remap_items() + + # load data + if self.mode == 'train': + self.inter_data = self._process_train_data() + elif self.mode == 'valid': + self.sample_valid = args.sample_valid + self.valid_prompt_id = args.valid_prompt_id + self.inter_data = self._process_valid_data() + self._construct_valid_text() + elif self.mode == 'test': + self.inter_data = self._process_test_data() + else: + raise NotImplementedError + + + def _load_data(self): + + with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f: + self.inters = json.load(f) + with open(self.index_file, 'r') as f: + self.indices = json.load(f) + with open(self.user_index_file, 'r') as f: + self.user_indices = json.load(f) + # with open(os.path.join(self.data_path, self.dataset 
+ self.index_file), 'r') as f: + # self.indices = json.load(f) + with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f: + self.item_feat = json.load(f) + + def _process_train_data(self): + + inter_data = [] + for uid in self.inters: + items = self.inters[uid][:-2] + for i in range(1, len(items)): + one_data = dict() + # one_data["user"] = uid + one_data["user"] = self.user_indices[uid] + one_data["item"] = "".join(self.indices[str(items[i])]) + one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`") + one_data["description"] = self.item_feat[str(items[i])]["description"] + history = items[:i] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + inters = ["".join(self.indices[str(j)]) for j in history] + inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history] + + + if self.add_prefix: + inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)] + inter_titles = [str(k + 1) + ". 
" + item_title for k, item_title in enumerate(inter_titles)] + + one_data["inters"] = self.his_sep.join(inters) + one_data["inter_titles"] = self.his_sep.join(inter_titles) + inter_data.append(one_data) + + if self.sample_num > 0: + all_inter_idx = range(len(inter_data)) + sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False) + inter_data = np.array(inter_data)[sample_idx].tolist() + + return inter_data + + def _process_valid_data(self): + + inter_data = [] + for uid in self.inters: + items = self.inters[uid] + one_data = dict() + one_data["item"] = "".join(self.indices[str(items[-2])]) + one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`") + one_data["description"] = self.item_feat[str(items[-2])]["description"] + + + history = items[:-2] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + inters = ["".join(self.indices[str(j)]) for j in history] + inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history] + + if self.add_prefix: + inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)] + inter_titles = [str(k + 1) + ". 
" + item_title for k, item_title in enumerate(inter_titles)] + + one_data["inters"] = self.his_sep.join(inters) + one_data["inter_titles"] = self.his_sep.join(inter_titles) + inter_data.append(one_data) + + if self.sample_num > 0: + all_inter_idx = range(len(inter_data)) + sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False) + inter_data = np.array(inter_data)[sample_idx].tolist() + + return inter_data + + def _process_test_data(self): + + inter_data = [] + for uid in self.inters: + items = self.inters[uid] + one_data = dict() + one_data["item"] = "".join(self.indices[str(items[-1])]) + one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`") + one_data["description"] = self.item_feat[str(items[-1])]["description"] + + history = items[:-1] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + inters = ["".join(self.indices[str(j)]) for j in history] + inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history] + + if self.add_prefix: + inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)] + inter_titles = [str(k + 1) + ". 
" + item_title for k, item_title in enumerate(inter_titles)] + + one_data["inters"] = self.his_sep.join(inters) + one_data["inter_titles"] = self.his_sep.join(inter_titles) + inter_data.append(one_data) + + if self.sample_num > 0: + all_inter_idx = range(len(inter_data)) + sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False) + inter_data = np.array(inter_data)[sample_idx].tolist() + + return inter_data + + def set_prompt(self, prompt_id): + + self.prompt_id = prompt_id + + def __len__(self): + if self.mode == 'train': + return len(self.inter_data) * self.prompt_sample_num + elif self.mode == 'valid': + return len(self.valid_text_data) + elif self.mode == 'test': + return len(self.inter_data) + else: + raise NotImplementedError + + def _construct_valid_text(self): + self.valid_text_data = [] + if self.sample_valid: + all_prompt_ids = range(len(self.prompts)) + for i in range(len(self.inter_data)): + d = self.inter_data[i] + prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False) + for prompt_id in prompt_ids: + prompt = self.prompts[prompt_id] + input, output = self._get_text_data(d, prompt) + self.valid_text_data.append({"input_ids": input, "labels": output}) + else: + self.prompt_sample_num = 1 + prompt = self.prompts[self.valid_prompt_id] + for i in range(len(self.inter_data)): + d = self.inter_data[i] + input, output = self._get_text_data(d, prompt) + self.valid_text_data.append({"input_ids": input, "labels": output}) + + def _get_text_data(self, data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction=instruction, response="") + output = sft_prompt.format(instruction=instruction, response=response) + + if self.mode == 'test': + return input, response + + return input, output + + def __getitem__(self, index): + + if self.mode == 'valid': + return self.valid_text_data[index] + + idx = index // 
self.prompt_sample_num + d = self.inter_data[idx] + + if self.mode == 'train': + prompt_id = random.randint(0, len(self.prompts) - 1) + elif self.mode == 'test': + prompt_id = self.prompt_id + + prompt = self.prompts[prompt_id] + + input, output = self._get_text_data(d, prompt) + + + return dict(input_ids=input, labels=output) + + +class ItemFeatFinetune(BaseDataset): + + def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1): + super().__init__(args) + + self.task = task.lower() + self.prompt_sample_num = prompt_sample_num + self.sample_num = sample_num + + self.prompts = all_prompt[self.task] + + # load data + self._load_data() + self.feat_data = self._process_data() + + + + def _load_data(self): + + # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f: + # self.indices = json.load(f) + with open(self.index_file, 'r') as f: + self.indices = json.load(f) + with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f: + self.item_feat = json.load(f) + + + def _process_data(self): + + feat_data = [] + for iid in self.item_feat: + feat = self.item_feat[iid] + index = "".join(self.indices[iid]) + feat["item"] = index + feat["title"] = feat["title"].strip().strip(".!?,;:`") + feat_data.append(feat) + + if self.sample_num > 0: + all_idx = range(len(feat_data)) + sample_idx = np.random.choice(all_idx, self.sample_num, replace=False) + + feat_data = np.array(feat_data)[sample_idx].tolist() + + return feat_data + + + def __len__(self): + return len(self.feat_data) * self.prompt_sample_num + + def _get_text_data(self, data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction = instruction, response = "") + output = sft_prompt.format(instruction = instruction, response = response) + + return input, output + + def __getitem__(self, index): + + idx = index // self.prompt_sample_num + d = 
self.feat_data[idx] + + prompt_id = random.randint(0, len(self.prompts) - 1) + + prompt = self.prompts[prompt_id] + + input, output = self._get_text_data(d, prompt) + + return dict(input_ids=input, labels=output) + + +class ItemSearchFinetune(BaseDataset): + + def __init__(self, args, mode="train", + prompt_sample_num=1, prompt_id=0, sample_num=-1): + super().__init__(args) + + self.mode = mode + self.prompt_sample_num = prompt_sample_num + self.prompt_id = prompt_id + self.sample_num = sample_num + + self.prompts = all_prompt["itemsearch"] + + # load data + self._load_data() + self.search_data = self._process_data() + + + + def _load_data(self): + + # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f: + # self.indices = json.load(f) + with open(self.index_file, 'r') as f: + self.indices = json.load(f) + with open(self.user_index_file, 'r') as f: + self.user_indices = json.load(f) + with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f: + self.user_info = json.load(f) + + + def _process_data(self): + + search_data = [] + user_explicit_preference = self.user_info["user_explicit_preference"] + user_vague_intention = self.user_info["user_vague_intention"] + if self.mode == 'train': + user_vague_intention = user_vague_intention["train"] + elif self.mode == 'test': + user_vague_intention = user_vague_intention["test"] + else: + raise NotImplementedError + + for uid in user_explicit_preference.keys(): + one_data = {} + user_ep = user_explicit_preference[uid] + user_vi = user_vague_intention[uid]["querys"] + one_data["explicit_preferences"] = user_ep + one_data["user_related_intention"] = user_vi[0] + one_data["item_related_intention"] = user_vi[1] + one_data["user"] = self.user_indices[uid] + + iid = user_vague_intention[uid]["item"] + inters = user_vague_intention[uid]["inters"] + + index = "".join(self.indices[str(iid)]) + one_data["item"] = index + + if self.max_his_len > 0: + inters = inters[-self.max_his_len:] 
+ inters = ["".join(self.indices[str(i)]) for i in inters] + if self.add_prefix: + inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)] + + one_data["inters"] = self.his_sep.join(inters) + + search_data.append(one_data) + + if self.sample_num > 0: + all_idx = range(len(search_data)) + sample_idx = np.random.choice(all_idx, self.sample_num, replace=False) + + search_data = np.array(search_data)[sample_idx].tolist() + + return search_data + + def set_prompt(self, prompt_id): + self.prompt_id = prompt_id + + def __len__(self): + if self.mode == 'train': + return len(self.search_data) * self.prompt_sample_num + elif self.mode == 'test': + return len(self.search_data) + else: + return len(self.search_data) + + + def _get_text_data(self, data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction = instruction, response = "") + output = sft_prompt.format(instruction = instruction, response = response) + + if self.mode == 'test': + return input, response + + return input, output + + def __getitem__(self, index): + + idx = index // self.prompt_sample_num + + d = self.search_data[idx] + if self.mode == 'train': + prompt_id = random.randint(0, len(self.prompts) - 1) + elif self.mode == 'test': + prompt_id = self.prompt_id + + prompt = self.prompts[prompt_id] + + d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"])) + all_querys = [d["user_related_intention"], d["item_related_intention"]] + d["query"] = random.choice(all_querys) + + input, output = self._get_text_data(d, prompt) + + return dict(input_ids=input, labels=output) + + + +class PreferenceObtainFinetune(BaseDataset): + + def __init__(self, args, prompt_sample_num=1, sample_num=-1): + super().__init__(args) + + self.prompt_sample_num = prompt_sample_num + self.sample_num = sample_num + + self.prompts = all_prompt["preferenceobtain"] + + # load data + self._load_data() 
+ self._remap_items() + + self.preference_data = self._process_data() + + + + def _load_data(self): + + with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f: + self.user_info = json.load(f) + with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f: + self.inters = json.load(f) + # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f: + # self.indices = json.load(f) + with open(self.index_file, 'r') as f: + self.indices = json.load(f) + with open(self.user_index_file, 'r') as f: + self.user_indices = json.load(f) + + + def _remap_items(self): + + self.remapped_inters = dict() + for uid, items in self.inters.items(): + new_items = ["".join(self.indices[str(i)]) for i in items] + self.remapped_inters[uid] = new_items + + def _process_data(self): + + preference_data = [] + user_explicit_preference = self.user_info["user_explicit_preference"] + + for uid in user_explicit_preference.keys(): + one_data = {} + one_data["user"] = self.user_indices[uid] + inters = self.remapped_inters[uid][:-3] + user_ep = user_explicit_preference[uid] + + if self.max_his_len > 0: + inters = inters[-self.max_his_len:] + if self.add_prefix: + inters = [str(k + 1) + ". 
" + item_idx for k, item_idx in enumerate(inters)] + + one_data["explicit_preferences"] = user_ep + one_data["inters"] = self.his_sep.join(inters) + + preference_data.append(one_data) + + if self.sample_num > 0: + all_idx = range(len(preference_data)) + sample_idx = np.random.choice(all_idx, self.sample_num, replace=False) + + preference_data = np.array(preference_data)[sample_idx].tolist() + + return preference_data + + def set_prompt(self, prompt_id): + self.prompt_id = prompt_id + + def __len__(self): + return len(self.preference_data) * self.prompt_sample_num + + + def _get_text_data(self, data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction = instruction, response = "") + output = sft_prompt.format(instruction = instruction, response = response) + + return input, output + + def __getitem__(self, index): + + idx = index // self.prompt_sample_num + + d = self.preference_data[idx] + prompt_id = random.randint(0, len(self.prompts) - 1) + + prompt = self.prompts[prompt_id] + + d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"])) + + input, output = self._get_text_data(d, prompt) + + return dict(input_ids=input, labels=output) + + + + + +class SeqRecTestDataset(BaseDataset): + + def __init__(self, args, prompt_id=0, sample_num=-1): + super().__init__(args) + + self.prompt_id = prompt_id + self.sample_num = sample_num + + self.prompt = all_prompt["seqrec"][self.prompt_id] + + # load data + self._load_data() + self._remap_items() + + self.inter_data = self._process_test_data() + + def _load_data(self): + + with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f: + self.inters = json.load(f) + with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f: + self.indices = json.load(f) + + + def _remap_items(self): + + self.remapped_inters = dict() + for uid, items in self.inters.items(): + 
new_items = ["".join(self.indices[str(i)]) for i in items] + self.remapped_inters[uid] = new_items + + def _process_test_data(self): + + inter_data = [] + for uid in self.remapped_inters: + items = self.remapped_inters[uid] + one_data = dict() + # one_data["user"] = uid + one_data["item"] = items[-1] + history = items[:-1] + if self.max_his_len > 0: + history = history[-self.max_his_len:] + if self.add_prefix: + history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)] + one_data["inters"] = self.his_sep.join(history) + inter_data.append(one_data) + + if self.sample_num > 0: + all_inter_idx = range(len(inter_data)) + sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False) + + inter_data = np.array(inter_data)[sample_idx].tolist() + + return inter_data + + def set_prompt(self, prompt_id): + self.prompt_id = prompt_id + + self.prompt = all_prompt["seqrec"][self.prompt_id] + + def __len__(self): + + return len(self.inter_data) + + def _get_text_data(self, data, prompt): + + instruction = prompt["instruction"].format(**data) + response = prompt["response"].format(**data) + + input = sft_prompt.format(instruction=instruction, response="") + + return input, response + + def __getitem__(self, index): + + d = self.inter_data[index] + input, target = self._get_text_data(d, self.prompt) + + return dict(input_ids=input, labels=target) \ No newline at end of file diff --git a/data_process/amazon18_data_process.py b/data_process/amazon18_data_process.py new file mode 100644 index 0000000000000000000000000000000000000000..4dc45b5711d5632a4f115acf4ba656cfaab513dc --- /dev/null +++ b/data_process/amazon18_data_process.py @@ -0,0 +1,299 @@ +import argparse +import collections +import gzip +import html +import json +import os +import random +import re +import torch +from tqdm import tqdm +import numpy as np +from utils import check_path, clean_text, amazon18_dataset2fullname, write_json_file, write_remap_index + +def load_ratings(file): + 
users, items, inters = set(), set(), set() + with open(file, 'r') as fp: + for line in tqdm(fp, desc='Load ratings'): + try: + item, user, rating, time = line.strip().split(',') + users.add(user) + items.add(item) + inters.add((user, item, float(rating), int(time))) + except ValueError: + print(line) + return users, items, inters + + +def load_meta_items(file): + items = {} + with gzip.open(file, "r") as fp: + for line in tqdm(fp, desc="Load metas"): + data = json.loads(line) + item = data["asin"] + title = clean_text(data["title"]) + + descriptions = data["description"] + descriptions = clean_text(descriptions) + + brand = data["brand"].replace("by\n", "").strip() + + categories = data["category"] + new_categories = [] + for category in categories: + if "" in category: + break + new_categories.append(category.strip()) + categories = ",".join(new_categories).strip() + + items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories} + # print(items[item]) + return items + + +def load_review_data(args, user2id, item2id): + + dataset_full_name = amazon18_dataset2fullname[args.dataset] + review_file_path = os.path.join(args.input_path, 'Review', dataset_full_name + '.json.gz') + + reviews = {} + + with gzip.open(review_file_path, "r") as fp: + + for line in tqdm(fp,desc='Load reviews'): + inter = json.loads(line) + try: + user = inter['reviewerID'] + item = inter['asin'] + if user in user2id and item in item2id: + uid = user2id[user] + iid = item2id[item] + else: + continue + if 'reviewText' in inter: + review = clean_text(inter['reviewText']) + else: + review = '' + if 'summary' in inter: + summary = clean_text(inter['summary']) + else: + summary = '' + reviews[str((uid,iid))]={"review":review, "summary":summary} + + except ValueError: + print(line) + + return reviews + + +def get_user2count(inters): + user2count = collections.defaultdict(int) + for unit in inters: + user2count[unit[0]] += 1 + return user2count + + +def 
get_item2count(inters):
    """Return item -> interaction count for (user, item, rating, time) tuples."""
    item2count = collections.defaultdict(int)
    for unit in inters:
        item2count[unit[1]] += 1
    return item2count


def generate_candidates(unit2count, threshold):
    """Return (units with count >= threshold, number of units filtered out)."""
    cans = set()
    for unit, count in unit2count.items():
        if count >= threshold:
            cans.add(unit)
    return cans, len(unit2count) - len(cans)


def filter_inters(inters, can_items=None,
                  user_k_core_threshold=0, item_k_core_threshold=0):
    """Filter interactions by metadata availability and iterative k-core.

    inters: list of (user, item, rating, timestamp) tuples.
    can_items: optional dict of items with metadata; interactions on other
        items are dropped first.
    user_k_core_threshold / item_k_core_threshold: minimum interaction counts;
        filtering repeats until every remaining user and item meets both.
    Returns the filtered interaction list.
    """
    new_inters = []

    # filter by meta items
    if can_items:
        print('\nFiltering by meta items: ')
        for unit in inters:
            if unit[1] in can_items.keys():
                new_inters.append(unit)
        inters, new_inters = new_inters, []
        print(' The number of inters: ', len(inters))

    # filter by k-core: dropping sparse users can make items sparse (and vice
    # versa), so iterate until a fixed point is reached.
    if user_k_core_threshold or item_k_core_threshold:
        print('\nFiltering by k-core:')
        idx = 0
        user2count = get_user2count(inters)
        item2count = get_item2count(inters)

        while True:
            new_user2count = collections.defaultdict(int)
            new_item2count = collections.defaultdict(int)
            users, n_filtered_users = generate_candidates(  # users is set
                user2count, user_k_core_threshold)
            items, n_filtered_items = generate_candidates(
                item2count, item_k_core_threshold)
            # Fixed point: nothing was filtered in this round.
            if n_filtered_users == 0 and n_filtered_items == 0:
                break
            # Keep only interactions whose user AND item both survived,
            # recounting as we go for the next round.
            for unit in inters:
                if unit[0] in users and unit[1] in items:
                    new_inters.append(unit)
                    new_user2count[unit[0]] += 1
                    new_item2count[unit[1]] += 1
            idx += 1
            inters, new_inters = new_inters, []
            user2count, item2count = new_user2count, new_item2count
            print(' Epoch %d The number of inters: %d, users: %d, items: %d'
                  % (idx, len(inters), len(user2count), len(item2count)))
    return inters


def make_inters_in_order(inters):
    """Group interactions by user and sort each user's list by timestamp."""
    user2inters, new_inters = collections.defaultdict(list), list()
    for inter in inters:
        user, item, rating, timestamp = inter
        user2inters[user].append((user, item, rating, timestamp))
    for user in user2inters:
        user_inters = user2inters[user]
        # Sort chronologically (element 3 is the timestamp).
        user_inters.sort(key=lambda d: d[3])
interacted_item = set()
        for inter in user_inters:
            if inter[1] in interacted_item:  # skip duplicate interactions with the same item
                continue
            interacted_item.add(inter[1])
            new_inters.append(inter)
    return new_inters


def preprocess_rating(args):
    """Load raw ratings + metadata and return filtered, ordered interactions.

    Returns (rating_inters, meta_items) where rating_inters is a list of
    (user_ID, item_ID, rating, timestamp) tuples and meta_items maps
    item ID -> metadata dict.
    """
    dataset_full_name = amazon18_dataset2fullname[args.dataset]

    print('Process rating data: ')
    print(' Dataset: ', args.dataset)

    # load ratings
    rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
    rating_users, rating_items, rating_inters = load_ratings(rating_file_path)

    # load item IDs with meta data
    meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
    meta_items = load_meta_items(meta_file_path)

    # 1. Filter items w/o meta data;
    # 2. K-core filtering;
    print('The number of raw inters: ', len(rating_inters))

    # Order + dedupe per user before k-core so duplicate interactions do not
    # inflate the counts.
    rating_inters = make_inters_in_order(rating_inters)

    rating_inters = filter_inters(rating_inters, can_items=meta_items,
                                  user_k_core_threshold=args.user_k,
                                  item_k_core_threshold=args.item_k)

    # sort interactions chronologically for each user
    rating_inters = make_inters_in_order(rating_inters)
    print('\n')

    # return: list of (user_ID, item_ID, rating, timestamp)
    return rating_inters, meta_items

def convert_inters2dict(inters):
    """Densely re-index users/items and group item indices per user.

    Returns (user2items, user2index, item2index); indices are assigned in
    first-seen order, so each user's item list stays chronological.
    """
    user2items = collections.defaultdict(list)
    user2index, item2index = dict(), dict()
    for inter in inters:
        user, item, rating, timestamp = inter
        if user not in user2index:
            user2index[user] = len(user2index)
        if item not in item2index:
            item2index[item] = len(item2index)
        user2items[user2index[user]].append(item2index[item])
    return user2items, user2index, item2index

def generate_data(args, rating_inters):
    """Leave-one-out split: last item -> test, second-to-last -> valid."""
    print('Split dataset: ')
    print(' Dataset: ', args.dataset)

    # generate train valid temp
    user2items, user2index, item2index = convert_inters2dict(rating_inters)
    train_inters, valid_inters, test_inters = dict(), dict(), dict()
    for u_index in range(len(user2index)):
inters = user2items[u_index] + # leave one out + train_inters[u_index] = [str(i_index) for i_index in inters[:-2]] + valid_inters[u_index] = [str(inters[-2])] + test_inters[u_index] = [str(inters[-1])] + assert len(user2items[u_index]) == len(train_inters[u_index]) + \ + len(valid_inters[u_index]) + len(test_inters[u_index]) + return user2items, train_inters, valid_inters, test_inters, user2index, item2index + +def convert_to_atomic_files(args, train_data, valid_data, test_data): + print('Convert dataset: ') + print(' Dataset: ', args.dataset) + uid_list = list(train_data.keys()) + uid_list.sort(key=lambda t: int(t)) + + with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.train.inter'), 'w') as file: + file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n') + for uid in uid_list: + item_seq = train_data[uid] + seq_len = len(item_seq) + for target_idx in range(1, seq_len): + target_item = item_seq[-target_idx] + seq = item_seq[:-target_idx][-50:] + file.write(f'{uid}\t{" ".join(seq)}\t{target_item}\n') + + with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.valid.inter'), 'w') as file: + file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n') + for uid in uid_list: + item_seq = train_data[uid][-50:] + target_item = valid_data[uid][0] + file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n') + + with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.test.inter'), 'w') as file: + file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n') + for uid in uid_list: + item_seq = (train_data[uid] + valid_data[uid])[-50:] + target_item = test_data[uid][0] + file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n') + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games') + parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering') + 
parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering') + parser.add_argument('--input_path', type=str, default='') + parser.add_argument('--output_path', type=str, default='') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + # load interactions from raw rating file + rating_inters, meta_items = preprocess_rating(args) + + + # split train/valid/temp + all_inters,train_inters, valid_inters, test_inters, user2index, item2index = generate_data(args, rating_inters) + + check_path(os.path.join(args.output_path, args.dataset)) + + write_json_file(all_inters, os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter.json')) + convert_to_atomic_files(args, train_inters, valid_inters, test_inters) + + item2feature = collections.defaultdict(dict) + for item, item_id in item2index.items(): + item2feature[item_id] = meta_items[item] + + # reviews = load_review_data(args, user2index, item2index) + + print("user:",len(user2index)) + print("item:",len(item2index)) + + write_json_file(item2feature, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item.json')) + # write_json_file(reviews, os.path.join(args.output_path, args.dataset, f'{args.dataset}.review.json')) + + + write_remap_index(user2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.user2id')) + write_remap_index(item2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item2id')) \ No newline at end of file diff --git a/data_process/amazon18_recbole_data_process.py b/data_process/amazon18_recbole_data_process.py new file mode 100644 index 0000000000000000000000000000000000000000..45aa42ae21ed41ff60e91db5b1f8019efcc90ef1 --- /dev/null +++ b/data_process/amazon18_recbole_data_process.py @@ -0,0 +1,226 @@ +import argparse +import collections +import gzip +import html +import json +import os +import random +import re +import torch +from tqdm import tqdm +import numpy as np +from utils import check_path, 
clean_text, amazon18_dataset2fullname,write_json_file,write_remap_index + +def load_ratings(file): + users, items, inters = set(), set(), set() + with open(file, 'r') as fp: + for line in tqdm(fp, desc='Load ratings'): + try: + item, user, rating, time = line.strip().split(',') + users.add(user) + items.add(item) + inters.add((user, item, float(rating), int(time))) + except ValueError: + print(line) + return users, items, inters + + +def load_meta_items(file): + items = {} + # re_tag = re.compile(']*>') + with gzip.open(file, "r") as fp: + for line in tqdm(fp, desc="Load metas"): + data = json.loads(line) + item = data["asin"] + title = clean_text(data["title"]) + + descriptions = data["description"] + descriptions = clean_text(descriptions) + # new_descriptions = [] + # for description in descriptions: + # description = re.sub(re_tag, '', description) + # new_descriptions.append(description.strip()) + # descriptions = " ".join(new_descriptions).strip() + + brand = data["brand"].replace("by\n", "").strip() + + categories = data["category"] + new_categories = [] + for category in categories: + if "" in category: + break + new_categories.append(category.strip()) + categories = ",".join(new_categories[1:]).strip() + + items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories} + # print(items[item]) + return items + + +def get_user2count(inters): + user2count = collections.defaultdict(int) + for unit in inters: + user2count[unit[0]] += 1 + return user2count + + +def get_item2count(inters): + item2count = collections.defaultdict(int) + for unit in inters: + item2count[unit[1]] += 1 + return item2count + + +def generate_candidates(unit2count, threshold): + cans = set() + for unit, count in unit2count.items(): + if count >= threshold: + cans.add(unit) + return cans, len(unit2count) - len(cans) + + +def filter_inters(inters, can_items=None, + user_k_core_threshold=0, item_k_core_threshold=0): + new_inters = [] + + # filter by meta 
items + if can_items: + print('\nFiltering by meta items: ') + for unit in inters: + if unit[1] in can_items.keys(): + new_inters.append(unit) + inters, new_inters = new_inters, [] + print(' The number of inters: ', len(inters)) + + # filter by k-core + if user_k_core_threshold or item_k_core_threshold: + print('\nFiltering by k-core:') + idx = 0 + user2count = get_user2count(inters) + item2count = get_item2count(inters) + + while True: + new_user2count = collections.defaultdict(int) + new_item2count = collections.defaultdict(int) + users, n_filtered_users = generate_candidates( # users is set + user2count, user_k_core_threshold) + items, n_filtered_items = generate_candidates( + item2count, item_k_core_threshold) + if n_filtered_users == 0 and n_filtered_items == 0: + break + for unit in inters: + if unit[0] in users and unit[1] in items: + new_inters.append(unit) + new_user2count[unit[0]] += 1 + new_item2count[unit[1]] += 1 + idx += 1 + inters, new_inters = new_inters, [] + user2count, item2count = new_user2count, new_item2count + print(' Epoch %d The number of inters: %d, users: %d, items: %d' + % (idx, len(inters), len(user2count), len(item2count))) + return inters + + +def make_inters_in_order(inters): + user2inters, new_inters = collections.defaultdict(list), list() + for inter in inters: + user, item, rating, timestamp = inter + user2inters[user].append((user, item, rating, timestamp)) + for user in user2inters: + user_inters = user2inters[user] + user_inters.sort(key=lambda d: d[3]) + interacted_item = set() + for inter in user_inters: + if inter[1] in interacted_item: # 过滤重复交互 + continue + interacted_item.add(inter[1]) + new_inters.append(inter) + return new_inters + + +def preprocess_rating(args): + dataset_full_name = amazon18_dataset2fullname[args.dataset] + + print('Process rating data: ') + print(' Dataset: ', args.dataset) + + # load ratings + rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv') + rating_users, 
rating_items, rating_inters = load_ratings(rating_file_path) + + # load item IDs with meta data + meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz') + meta_items = load_meta_items(meta_file_path) + + # 1. Filter items w/o meta data; + # 2. K-core filtering; + print('The number of raw inters: ', len(rating_inters)) + + rating_inters = make_inters_in_order(rating_inters) + + rating_inters = filter_inters(rating_inters, can_items=meta_items, + user_k_core_threshold=args.user_k, + item_k_core_threshold=args.item_k) + + # sort interactions chronologically for each user + rating_inters = make_inters_in_order(rating_inters) + print('\n') + + # return: list of (user_ID, item_ID, rating, timestamp) + return rating_inters, meta_items + +def save_inter(args, inters): + print('Convert dataset: ') + print(' Dataset: ', args.dataset) + + with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter'), 'w') as file: + file.write('user_id:token\titem_id:token\trating:float\ttimestamp:float\n') + for inter in inters: + user, item, rating, timestamp = inter + file.write(f'{user}\t{item}\t{rating}\t{timestamp}\n') + + +def save_feat(args, feat, all_items): + iid_list = list(feat.keys()) + num_item = 0 + with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.item'), 'w') as file: + # "title": title, "description": descriptions, "brand": brand, "categories": categories + file.write('item_id:token\ttitle:token_seq\tbrand:token\tcategories:token_seq\n') + for iid in iid_list: + if iid in all_items: + num_item += 1 + title, brand, categories = feat[iid]["title"], feat[iid]["brand"], feat[iid]["categories"] + file.write(f'{iid}\t{title}\t{brand}\t{categories}\n') + print("num_item: ", num_item) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games') + parser.add_argument('--user_k', type=int, default=5, help='user 
k-core filtering') + parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering') + parser.add_argument('--input_path', type=str, default='') + parser.add_argument('--output_path', type=str, default='') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + # load interactions from raw rating file + rating_inters, meta_items = preprocess_rating(args) + + check_path(os.path.join(args.output_path, args.dataset)) + + + all_items = set() + for inter in rating_inters: + user, item, rating, timestamp = inter + all_items.add(item) + + print("total item: ", len(list(all_items))) + + save_inter(args,rating_inters) + save_feat(args,meta_items, all_items) + + diff --git a/data_process/amazon_text_emb.py b/data_process/amazon_text_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..7818592a5e96d380cc232c0c3b2d4b95b4e93a71 --- /dev/null +++ b/data_process/amazon_text_emb.py @@ -0,0 +1,139 @@ +import argparse +import collections +import gzip +import html +import json +import os +import random +import re +import torch +from tqdm import tqdm +import numpy as np +from utils import * +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel + + +def load_data(args): + + item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json') + item2feature = load_json(item2feature_path) + + return item2feature + +def generate_text(item2feature, features): + item_text_list = [] + + for item in item2feature: + data = item2feature[item] + text = [] + for meta_key in features: + if meta_key in data: + meta_value = clean_text(data[meta_key]) + text.append(meta_value.strip()) + + item_text_list.append([int(item), text]) + + return item_text_list + +def preprocess_text(args): + print('Process text data: ') + print(' Dataset: ', args.dataset) + + item2feature = load_data(args) + # load item text and clean + item_text_list = generate_text(item2feature, ['title', 
'description']) + # item_text_list = generate_text(item2feature, ['title']) + # return: list of (item_ID, cleaned_item_text) + return item_text_list + +def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1): + print(f'Generate Text Embedding: ') + print(' Dataset: ', args.dataset) + + items, texts = zip(*item_text_list) + order_texts = [[0]] * len(items) + for item, text in zip(items, texts): + order_texts[item] = text + for text in order_texts: + assert text != [0] + + embeddings = [] + start, batch_size = 0, 1 + with torch.no_grad(): + while start < len(order_texts): + if (start+1)%100==0: + print("==>",start+1) + field_texts = order_texts[start: start + batch_size] + # print(field_texts) + field_texts = zip(*field_texts) + + field_embeddings = [] + for sentences in field_texts: + sentences = list(sentences) + # print(sentences) + if word_drop_ratio > 0: + print(f'Word drop with p={word_drop_ratio}') + new_sentences = [] + for sent in sentences: + new_sent = [] + sent = sent.split(' ') + for wd in sent: + rd = random.random() + if rd > word_drop_ratio: + new_sent.append(wd) + new_sent = ' '.join(new_sent) + new_sentences.append(new_sent) + sentences = new_sentences + encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len, + truncation=True, return_tensors='pt',padding="longest").to(args.device) + outputs = model(input_ids=encoded_sentences.input_ids, + attention_mask=encoded_sentences.attention_mask) + + masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1) + mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True) + mean_output = mean_output.detach().cpu() + field_embeddings.append(mean_output) + + field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0) + embeddings.append(field_mean_embedding) + start += batch_size + + embeddings = torch.cat(embeddings, dim=0).numpy() + print('Embeddings shape: ', 
embeddings.shape) + + file = os.path.join(args.root, args.dataset + '.emb-' + args.plm_name + "-td" + ".npy") + np.save(file, embeddings) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games') + parser.add_argument('--root', type=str, default="") + parser.add_argument('--gpu_id', type=int, default=2, help='ID of running GPU') + parser.add_argument('--plm_name', type=str, default='llama') + parser.add_argument('--plm_checkpoint', type=str, + default='') + parser.add_argument('--max_sent_len', type=int, default=2048) + parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + args.root = os.path.join(args.root, args.dataset) + + device = set_device(args.gpu_id) + args.device = device + + item_text_list = preprocess_text(args) + + plm_tokenizer, plm_model = load_plm(args.plm_checkpoint) + if plm_tokenizer.pad_token_id is None: + plm_tokenizer.pad_token_id = 0 + plm_model = plm_model.to(device) + + generate_item_embedding(args, item_text_list,plm_tokenizer, + plm_model, word_drop_ratio=args.word_drop_ratio) + + diff --git a/data_process/amazon_user_emb.py b/data_process/amazon_user_emb.py new file mode 100644 index 0000000000000000000000000000000000000000..9511b57a051657117aaea753bfe1d5558c4385d2 --- /dev/null +++ b/data_process/amazon_user_emb.py @@ -0,0 +1,145 @@ +import argparse +import collections +import gzip +import html +import json +import os +import random +import re +import torch +from tqdm import tqdm +import numpy as np +from utils import * +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel + + +def load_data(args): + + item2feature_path = os.path.join(args.root, f'{args.dataset}.user.json') + item2feature = load_json(item2feature_path) + + return item2feature + +def 
generate_text(item2feature, features): + item_text_list = [] + + for item in item2feature: + data = item2feature[item] + text = [] + + for i in range(len(data)): + meta_value = clean_text(data[i]) + text.append(meta_value.strip()) + + # for meta_key in features: + # if meta_key in data: + # meta_value = clean_text(data[meta_key]) + # text.append(meta_value.strip()) + + item_text_list.append([int(item), text]) + + return item_text_list + +def preprocess_text(args): + print('Process text data ......') + print('Dataset:', args.dataset) + + item2feature = load_data(args) + item2feature = item2feature['user_explicit_preference'] + # load item text and clean + item_text_list = generate_text(item2feature) + # item_text_list = generate_text(item2feature, ['user_explicit_preference']) + # item_text_list = generate_text(item2feature, ['title']) + # return: list of (item_ID, cleaned_item_text) + return item_text_list + +def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1): + print(f'Generate Text Embedding ......') + print('Dataset:', args.dataset) + + items, texts = zip(*item_text_list) + order_texts = [[0]] * len(items) + for item, text in zip(items, texts): + order_texts[item] = text + for text in order_texts: + assert text != [0] + + embeddings = [] + start, batch_size = 0, 1 + with torch.no_grad(): + while start < len(order_texts): + if (start+1) % 100 == 0: + print("==>", start + 1) + field_texts = order_texts[start: start + batch_size] + # print(field_texts) + field_texts = zip(*field_texts) + + field_embeddings = [] + for sentences in field_texts: + sentences = list(sentences) + # print(sentences) + if word_drop_ratio > 0: + print(f'Word drop with p={word_drop_ratio}') + new_sentences = [] + for sent in sentences: + new_sent = [] + sent = sent.split(' ') + for wd in sent: + rd = random.random() + if rd > word_drop_ratio: + new_sent.append(wd) + new_sent = ' '.join(new_sent) + new_sentences.append(new_sent) + sentences = new_sentences 
+ encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len, + truncation=True, return_tensors='pt',padding="longest").to(args.device) + outputs = model(input_ids=encoded_sentences.input_ids, + attention_mask=encoded_sentences.attention_mask) + + masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1) + mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True) + mean_output = mean_output.detach().cpu() + field_embeddings.append(mean_output) + + field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0) + embeddings.append(field_mean_embedding) + start += batch_size + + embeddings = torch.cat(embeddings, dim=0).numpy() + print('Embeddings shape: ', embeddings.shape) + + # file = os.path.join(args.root, args.dataset + '.emb-' + args.plm_name + "-td" + ".npy") + np.save(args.save_path, embeddings) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games') + parser.add_argument('--root', type=str, default="") + parser.add_argument('--gpu_id', type=int, default=2, help='ID of running GPU') + parser.add_argument('--plm_name', type=str, default='llama') + parser.add_argument('--plm_checkpoint', type=str, + default='') + parser.add_argument('--max_sent_len', type=int, default=2048) + parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default') + parser.add_argument('--save_path', type=str, default="") + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + args.root = os.path.join(args.root, args.dataset) + + device = set_device(args.gpu_id) + args.device = device + + item_text_list = preprocess_text(args) + + plm_tokenizer, plm_model = load_plm(args.plm_checkpoint) + if plm_tokenizer.pad_token_id is None: + plm_tokenizer.pad_token_id = 0 + plm_model = plm_model.to(device) + + 
generate_item_embedding(args, item_text_list,plm_tokenizer, + plm_model, word_drop_ratio=args.word_drop_ratio) \ No newline at end of file diff --git a/data_process/get_llm_output.py b/data_process/get_llm_output.py new file mode 100644 index 0000000000000000000000000000000000000000..a7a7cd1ee8e9710319cfd16cdb09283ebd59b699 --- /dev/null +++ b/data_process/get_llm_output.py @@ -0,0 +1,374 @@ + + +import argparse +import os +import os.path as osp +import random +import time +from logging import getLogger +import openai +from utils import get_res_batch, load_json, intention_prompt, preference_prompt_1, preference_prompt_2, amazon18_dataset2fullname, write_json_file +import json + + + +def get_intention_train(args, inters, item2feature, reviews, api_info): + + intention_train_output_file = os.path.join(args.root,"intention_train.json") + + + # Suggest modifying the prompt based on different datasets + prompt = intention_prompt + dataset_full_name = amazon18_dataset2fullname[args.dataset] + dataset_full_name = dataset_full_name.replace("_", " ").lower() + print(dataset_full_name) + + prompt_list = [] + + inter_data = [] + + for (user,item_list) in inters.items(): + user = int(user) + item = int(item_list[-3]) + history = item_list[:-3] + + inter_data.append((user,item,history)) + + review = reviews[str((user, item))]["review"] + item_title = item2feature[str(item)]["title"] + input_prompt = prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review) + prompt_list.append(input_prompt) + + st = 0 + with open(intention_train_output_file, mode='a') as f: + + while st < len(prompt_list): + # while st < 3: + print(st) + # if st < 25631: + # st += args.batchsize + # continue + + + res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info) + + for i, answer in enumerate(res): + user, item, history = inter_data[st+i] + # print(answer) + # print("=============") + + if answer == '': + print("answer null error") + 
answer = "I enjoy high-quality item." + + if answer.strip().count('\n') != 1: + if 'haracteristics:' in answer: + answer = answer.strip().split("The item's characteristics:") + else: + answer = answer.strip().split("The item's characteristic:") + else: + answer = answer.strip().split('\n') + + if '' in answer: + answer.remove('') + + if len(answer) == 1: + print(answer) + user_preference = item_character = answer[0] + elif len(answer) >= 3: + print(answer) + answer = answer[-1] + user_preference = item_character = answer + else: + user_preference, item_character = answer + + if ':' in user_preference: + idx = user_preference.index(':') + user_preference = user_preference[idx+1:] + user_preference = user_preference.strip().replace('}','') + user_preference = user_preference.replace('\n','') + + if ':' in item_character: + idx = item_character.index(':') + item_character = item_character[idx+1:] + item_character = item_character.strip().replace('}','') + item_character = item_character.replace('\n','') + + + dict = {"user":user, "item":item, "inters": history, + "user_related_intention":user_preference, "item_related_intention": item_character} + + json.dump(dict, f) + f.write("\n") + + st += args.batchsize + + return intention_train_output_file + + +def get_intention_test(args, inters, item2feature, reviews, api_info): + + intention_test_output_file = os.path.join(args.root,"intention_test.json") + + # Suggest modifying the prompt based on different datasets + prompt = intention_prompt + dataset_full_name = amazon18_dataset2fullname[args.dataset] + dataset_full_name = dataset_full_name.replace("_", " ").lower() + print(dataset_full_name) + + prompt_list = [] + + inter_data = [] + + for (user,item_list) in inters.items(): + user = int(user) + item = int(item_list[-1]) + history = item_list[:-1] + + inter_data.append((user,item,history)) + + review = reviews[str((user, item))]["review"] + item_title = item2feature[str(item)]["title"] + input_prompt = 
prompt.format(item_title=item_title,dataset_full_name=dataset_full_name,review=review) + prompt_list.append(input_prompt) + + st = 0 + with open(intention_test_output_file, mode='a') as f: + + while st < len(prompt_list): + # while st < 3: + print(st) + # if st < 4623: + # st += args.batchsize + # continue + + res = get_res_batch(args.model_name, prompt_list[st:st+args.batchsize], args.max_tokens, api_info) + + for i, answer in enumerate(res): + user, item, history = inter_data[st+i] + + if answer == '': + print("answer null error") + answer = "I enjoy high-quality item." + + if answer.strip().count('\n') != 1: + if 'haracteristics:' in answer: + answer = answer.strip().split("The item's characteristics:") + else: + answer = answer.strip().split("The item's characteristic:") + else: + answer = answer.strip().split('\n') + + if '' in answer: + answer.remove('') + + if len(answer) == 1: + print(answer) + user_preference = item_character = answer[0] + elif len(answer) >= 3: + print(answer) + answer = answer[-1] + user_preference = item_character = answer + else: + user_preference, item_character = answer + + if ':' in user_preference: + idx = user_preference.index(':') + user_preference = user_preference[idx+1:] + user_preference = user_preference.strip().replace('}','') + user_preference = user_preference.replace('\n','') + + if ':' in item_character: + idx = item_character.index(':') + item_character = item_character[idx+1:] + item_character = item_character.strip().replace('}','') + item_character = item_character.replace('\n','') + + + dict = {"user":user, "item":item, "inters": history, + "user_related_intention":user_preference, "item_related_intention": item_character} + + json.dump(dict, f) + f.write("\n") + + st += args.batchsize + + return intention_test_output_file + + + + +def get_user_preference(args, inters, item2feature, reviews, api_info): + + preference_output_file = os.path.join(args.root,"user_preference.json") + + + # Suggest modifying the prompt 
based on different datasets + prompt_1 = preference_prompt_1 + prompt_2 = preference_prompt_2 + + + dataset_full_name = amazon18_dataset2fullname[args.dataset] + dataset_full_name = dataset_full_name.replace("_", " ").lower() + print(dataset_full_name) + + prompt_list_1 = [] + prompt_list_2 = [] + + users = [] + + for (user,item_list) in inters.items(): + users.append(user) + history = item_list[:-3] + item_titles = [] + for j, item in enumerate(history): + item_titles.append(str(j+1) + '.' + item2feature[str(item)]["title"]) + if len(item_titles) > args.max_his_len: + item_titles = item_titles[-args.max_his_len:] + item_titles = ", ".join(item_titles) + + input_prompt_1 = prompt_1.format(dataset_full_name=dataset_full_name, item_titles=item_titles) + input_prompt_2 = prompt_2.format(dataset_full_name=dataset_full_name, item_titles=item_titles) + + prompt_list_1.append(input_prompt_1) + prompt_list_2.append(input_prompt_2) + + + st = 0 + with open(preference_output_file, mode='a') as f: + + while st < len(prompt_list_1): + # while st < 3: + print(st) + # if st < 22895: + # st += args.batchsize + # continue + + res_1 = get_res_batch(args.model_name, prompt_list_1[st:st + args.batchsize], args.max_tokens, api_info) + res_2 = get_res_batch(args.model_name, prompt_list_2[st:st + args.batchsize], args.max_tokens, api_info) + for i, answers in enumerate(zip(res_1, res_2)): + + user = users[st + i] + + answer_1, answer_2 = answers + # print(answers) + # print("=============") + + if answer_1 == '': + print("answer null error") + answer_1 = "I enjoy high-quality item." + + if answer_2 == '': + print("answer null error") + answer_2 = "I enjoy high-quality item." 
+ + if answer_2.strip().count('\n') != 1: + if 'references:' in answer_2: + answer_2 = answer_2.strip().split("Short-term preferences:") + else: + answer_2 = answer_2.strip().split("Short-term preference:") + else: + answer_2 = answer_2.strip().split('\n') + + if '' in answer_2: + answer_2.remove('') + + if len(answer_2) == 1: + print(answer_2) + long_preference = short_preference = answer_2[0] + elif len(answer_2) >= 3: + print(answer_2) + answer_2 = answer_2[-1] + long_preference = short_preference = answer_2 + else: + long_preference, short_preference = answer_2 + + if ':' in long_preference: + idx = long_preference.index(':') + long_preference = long_preference[idx+1:] + long_preference = long_preference.strip().replace('}','') + long_preference = long_preference.replace('\n','') + + if ':' in short_preference: + idx = short_preference.index(':') + short_preference = short_preference[idx+1:] + short_preference = short_preference.strip().replace('}','') + short_preference = short_preference.replace('\n','') + + dict = {"user":user,"user_preference":[answer_1, long_preference, short_preference]} + # print(dict) + json.dump(dict, f) + f.write("\n") + + st += args.batchsize + + return preference_output_file + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--dataset', type=str, default='Instruments', help='Instruments / Arts / Games') + parser.add_argument('--root', type=str, default='') + parser.add_argument('--api_info', type=str, default='./api_info.json') + parser.add_argument('--model_name', type=str, default='text-davinci-003') + parser.add_argument('--max_tokens', type=int, default=512) + parser.add_argument('--batchsize', type=int, default=16) + parser.add_argument('--max_his_len', type=int, default=20) + return parser.parse_args() + +if __name__ == "__main__": + args = parse_args() + + args.root = os.path.join(args.root, args.dataset) + + api_info = load_json(args.api_info) + openai.api_key = api_info["api_key_list"].pop() + 
+ + inter_path = os.path.join(args.root, f'{args.dataset}.inter.json') + inters = load_json(inter_path) + + + item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json') + item2feature = load_json(item2feature_path) + + reviews_path = os.path.join(args.root, f'{args.dataset}.review.json') + reviews = load_json(reviews_path) + + intention_train_output_file = get_intention_train(args, inters, item2feature, reviews, api_info) + intention_test_output_file = get_intention_test(args, inters, item2feature, reviews ,api_info) + preference_output_file = get_user_preference(args, inters, item2feature, reviews, api_info) + + intention_train = {} + intention_test = {} + user_preference = {} + + with open(intention_train_output_file, "r") as f: + for line in f: + # print(line) + content = json.loads(line) + if content["user"] not in intention_train: + intention_train[content["user"]] = {"item":content["item"], + "inters":content["inters"], + "querys":[ content["user_related_intention"], content["item_related_intention"] ]} + + + with open(intention_test_output_file, "r") as f: + for line in f: + content = json.loads(line) + if content["user"] not in intention_train: + intention_test[content["user"]] = {"item":content["item"], + "inters":content["inters"], + "querys":[ content["user_related_intention"], content["item_related_intention"] ]} + + + with open(preference_output_file, "r") as f: + for line in f: + content = json.loads(line) + user_preference[content["user"]] = content["user_preference"] + + user_dict = { + "user_explicit_preference": user_preference, + "user_vague_intention": {"train": intention_train, "test": intention_test}, + } + + write_json_file(user_dict, os.path.join(args.root, f'{args.dataset}.user.json')) diff --git a/data_process/utils.py b/data_process/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c32b5f588d8f124be50b33c26ad499643151df9d --- /dev/null +++ b/data_process/utils.py @@ -0,0 +1,238 @@ +import html 
import json
import os
import pickle
import re
import time

import torch
# import gensim
from transformers import AutoModel, AutoTokenizer
import collections
import openai


def get_res_batch(model_name, prompt_list, max_tokens, api_info):
    """Query the (legacy openai v0.x) Completion API for a batch of prompts.

    Retries indefinitely on transient API errors, rotating to the next key in
    ``api_info["api_key_list"]`` on auth/quota failures.  Returns a list of
    stripped completion strings (one per prompt), or ``None`` on any
    unexpected error.
    """
    while True:
        try:
            res = openai.Completion.create(
                model=model_name,
                prompt=prompt_list,
                temperature=0.4,
                max_tokens=max_tokens,
                top_p=1,
                frequency_penalty=0,
                presence_penalty=0
            )
            output_list = []
            for choice in res['choices']:
                output = choice['text'].strip()
                output_list.append(output)

            return output_list

        except openai.error.AuthenticationError as e:
            print(e)
            # Bad key: rotate to the next one.  NOTE(review): once the key list
            # is exhausted, .pop() raises IndexError, which is swallowed by the
            # generic handler below and silently returns None.
            openai.api_key = api_info["api_key_list"].pop()
            time.sleep(10)
        except openai.error.RateLimitError as e:
            print(e)
            # Quota exhausted is unrecoverable for this key -> rotate; plain
            # rate limiting is transient -> just back off and retry.
            if str(e) == "You exceeded your current quota, please check your plan and billing details.":
                openai.api_key = api_info["api_key_list"].pop()
                time.sleep(10)
            else:
                print('\nopenai.error.RateLimitError\nRetrying...')
                time.sleep(10)
        except openai.error.ServiceUnavailableError as e:
            print(e)
            print('\nopenai.error.ServiceUnavailableError\nRetrying...')
            time.sleep(10)
        except openai.error.Timeout:
            print('\nopenai.error.Timeout\nRetrying...')
            time.sleep(10)
        except openai.error.APIError as e:
            print(e)
            print('\nopenai.error.APIError\nRetrying...')
            time.sleep(10)
        except openai.error.APIConnectionError as e:
            print(e)
            print('\nopenai.error.APIConnectionError\nRetrying...')
            time.sleep(10)
        except Exception as e:
            # Anything unexpected aborts the batch instead of looping forever.
            print(e)
            return None


def check_path(path):
    """Create directory *path* (and parents) if it does not exist yet."""
    if not os.path.exists(path):
        os.makedirs(path)


def set_device(gpu_id):
    """Return torch.device for *gpu_id*; -1 (or no CUDA) selects the CPU."""
    if gpu_id == -1:
        return torch.device('cpu')
    else:
        return torch.device(
            'cuda:' + str(gpu_id) if torch.cuda.is_available() else 'cpu')

def load_plm(model_path='bert-base-uncased'):
    """Load a HuggingFace tokenizer + encoder model pair from *model_path*."""
    tokenizer = AutoTokenizer.from_pretrained(model_path,)

    print("Load Model:", model_path)

    model = AutoModel.from_pretrained(model_path,low_cpu_mem_usage=True,)
    return tokenizer, model

def load_json(file):
    """Read and return the JSON content of *file*."""
    with open(file, 'r') as f:
        data = json.load(f)
    return data

def clean_text(raw_text):
    """Normalize a metadata text field (str, list of str, or dict) to one string.

    Unescapes HTML entities, strips markup/quotes/newlines, collapses any run
    of trailing dots to exactly one, and drops texts >= 2000 chars entirely.
    NOTE(review): the patterns r']*>' below look extraction-mangled; the
    intended tag-stripping regex was presumably r'<[^>]*>' — confirm upstream.
    """
    if isinstance(raw_text, list):
        new_raw_text=[]
        for raw in raw_text:
            raw = html.unescape(raw)
            raw = re.sub(r']*>', '', raw)
            raw = re.sub(r'["\n\r]*', '', raw)
            new_raw_text.append(raw.strip())
        cleaned_text = ' '.join(new_raw_text)
    else:
        if isinstance(raw_text, dict):
            # str(dict)[1:-1] drops the surrounding braces only.
            cleaned_text = str(raw_text)[1:-1].strip()
        else:
            cleaned_text = raw_text.strip()
        cleaned_text = html.unescape(cleaned_text)
        cleaned_text = re.sub(r']*>', '', cleaned_text)
        cleaned_text = re.sub(r'["\n\r]*', '', cleaned_text)
    # Walk backwards over any run of trailing '.' so the text ends with
    # exactly one period.
    index = -1
    while -index < len(cleaned_text) and cleaned_text[index] == '.':
        index -= 1
    index += 1
    if index == 0:
        cleaned_text = cleaned_text + '.'
    else:
        cleaned_text = cleaned_text[:index] + '.'
    # Overly long fields are discarded entirely (kept cheap for prompting).
    if len(cleaned_text) >= 2000:
        cleaned_text = ''
    return cleaned_text

def load_pickle(filename):
    """Unpickle and return the object stored in *filename*."""
    with open(filename, "rb") as f:
        return pickle.load(f)


def make_inters_in_order(inters):
    """Group (user, item, rating, timestamp) tuples by user, each user's
    interactions sorted by timestamp, and return the flattened list."""
    user2inters, new_inters = collections.defaultdict(list), list()
    for inter in inters:
        user, item, rating, timestamp = inter
        user2inters[user].append((user, item, rating, timestamp))
    for user in user2inters:
        user_inters = user2inters[user]
        user_inters.sort(key=lambda d: d[3])
        for inter in user_inters:
            new_inters.append(inter)
    return new_inters

def write_json_file(dic, file):
    """Dump *dic* to *file* as pretty-printed JSON."""
    print('Writing json file: ',file)
    with open(file, 'w') as fp:
        json.dump(dic, fp, indent=4)

def write_remap_index(unit2index, file):
    """Write a tab-separated `unit<TAB>index` mapping, one pair per line."""
    print('Writing remap file: ',file)
    with open(file, 'w') as fp:
        for unit in unit2index:
            fp.write(unit + '\t' + str(unit2index[unit]) + '\n')


# Prompt asking the LLM to infer user preference + item characteristics from
# a single review.  Placeholders: dataset_full_name, item_title, review.
intention_prompt = "After purchasing a {dataset_full_name} item named \"{item_title}\", the user left a comment expressing his opinion and personal preferences. The user's comment is as follows: \n\"{review}\" " \
    "\nAs we all know, user comments often contain information about both their personal preferences and the characteristics of the item they interacted with. From this comment, you can infer both the user's personal preferences and the characteristics of the item. " \
    "Please describe your inferred user preferences and item characteristics in the first person and in the following format:\n\nMy preferences: []\nThe item's characteristics: []\n\n" \
    "Note that your inference of the personalized preferences should not include any information about the title of the item."


# Prompt for a third-person summary of overall user preference from the
# purchase history.  Placeholders: dataset_full_name, item_titles.
preference_prompt_1 = "Suppose the user has bought a variety of {dataset_full_name} items, they are: \n{item_titles}. \nAs we all know, these historically purchased items serve as a reflection of the user's personalized preferences. " \
    "Please analyze the user's personalized preferences based on the items he has bought and provide a brief third-person summary of the user's preferences, highlighting the key factors that influence his choice of items. Avoid listing specific items and do not list multiple examples. " \
    "Your analysis should be brief and in the third person."

# Prompt splitting the analysis into long-term vs short-term preferences.
preference_prompt_2 = "Given a chronological list of {dataset_full_name} items that a user has purchased, we can analyze his long-term and short-term preferences. Long-term preferences are inherent characteristics of the user, which are reflected in all the items he has interacted with over time. Short-term preferences are the user's recent preferences, which are reflected in some of the items he has bought more recently. " \
    "To determine the user's long-term preferences, please analyze the contents of all the items he has bought. Look for common features that appear frequently across the user's shopping records. To determine the user's short-term preferences, focus on the items he has bought most recently. Identify any new or different features that have emerged in the user's shopping records. " \
    "Here is a chronological list of items that the user has bought: \n{item_titles}. \nPlease provide separate analyses for the user's long-term and short-term preferences. Your answer should be concise and general, without listing specific items. Your answer should be in the third person and in the following format:\n\nLong-term preferences: []\nShort-term preferences: []\n\n"


# remove 'Magazine', 'Gift', 'Music', 'Kindle'
# NOTE(review): despite the comment above, 'Kindle' is still present in the
# list below — confirm whether it should actually be removed.
amazon18_dataset_list = [
    'Appliances', 'Beauty',
    'Fashion', 'Software', 'Luxury', 'Scientific', 'Pantry',
    'Instruments', 'Arts', 'Games', 'Office', 'Garden',
    'Food', 'Cell', 'CDs', 'Automotive', 'Toys',
    'Pet', 'Tools', 'Kindle', 'Sports', 'Movies',
    'Electronics', 'Home', 'Clothing', 'Books'
]

# Short alias -> official Amazon-2018 category name.
amazon18_dataset2fullname = {
    'Beauty': 'All_Beauty',
    'Fashion': 'AMAZON_FASHION',
    'Appliances': 'Appliances',
    'Arts': 'Arts_Crafts_and_Sewing',
    'Automotive': 'Automotive',
    'Books': 'Books',
    'CDs': 'CDs_and_Vinyl',
    'Cell': 'Cell_Phones_and_Accessories',
    'Clothing': 'Clothing_Shoes_and_Jewelry',
    'Music': 'Digital_Music',
    'Electronics': 'Electronics',
    'Gift': 'Gift_Cards',
    'Food': 'Grocery_and_Gourmet_Food',
    'Home': 'Home_and_Kitchen',
    'Scientific': 'Industrial_and_Scientific',
    'Kindle': 'Kindle_Store',
    'Luxury': 'Luxury_Beauty',
    'Magazine': 'Magazine_Subscriptions',
    'Movies': 'Movies_and_TV',
    'Instruments': 'Musical_Instruments',
    'Office': 'Office_Products',
    'Garden': 'Patio_Lawn_and_Garden',
    'Pet': 'Pet_Supplies',
    'Pantry': 'Prime_Pantry',
    'Software': 'Software',
    'Sports': 'Sports_and_Outdoors',
    'Tools': 'Tools_and_Home_Improvement',
    'Toys': 'Toys_and_Games',
    'Games': 'Video_Games'
}

amazon14_dataset_list = [
    'Beauty','Toys','Sports'
]

# Short alias -> Amazon-2014 category name.
amazon14_dataset2fullname = {
    'Beauty': 'Beauty',
    'Sports': 'Sports_and_Outdoors',
    'Toys': 'Toys_and_Games',
}

# c1. c2. c3. c4.
# Alternative item-text feature sets used when building item representations.
amazon_text_feature1 = ['title', 'category', 'brand']

# re-order
amazon_text_feature1_ro1 = ['brand', 'main_cat', 'category', 'title']

# remove
amazon_text_feature1_re1 = ['title']

amazon_text_feature2 = ['title']

amazon_text_feature3 = ['description']

amazon_text_feature4 = ['description', 'main_cat', 'category', 'brand']

amazon_text_feature5 = ['title', 'description']


# ---- evaluate-finetuned.py ----
# Distributed (one process per GPU) beam-search evaluation of a fine-tuned
# LLaMA recommender: generates item indices, gathers per-rank results, and
# aggregates Hit@K / NDCG@K over one or more test prompts.
import argparse
import json
import os
import sys

import torch
import transformers
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel
from peft import PeftModel
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import TestCollator
from prompt import all_prompt
from evaluate import get_topk_results, get_metrics_results

parser = argparse.ArgumentParser(description = 'rqllama-evaluate')
parser = parse_evaluate_args(parser)
args = parser.parse_args()

set_seed(args.seed)
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
torch.cuda.set_device(local_rank)
if local_rank == 0:
    print(vars(args))

dist.init_process_group(backend = "nccl", world_size = world_size, rank = local_rank)

device_map = {"": local_rank}
device = torch.device("cuda",local_rank)

# Tokenizer comes from the checkpoint dir (it contains the added index
# tokens); the base model is resized accordingly before LoRA weights load.
tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
base_model = LlamaForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.float16, low_cpu_mem_usage = True, device_map = device_map)
base_model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base_model, args.ckpt_path, torch_dtype = torch.float16, device_map = device_map)

model = DistributedDataParallel(model, device_ids = [local_rank])

# NOTE(review): if test_task is none of the three names below, prompt_ids is
# never assigned and the later use raises NameError — consider an else/raise.
if args.test_prompt_ids == "all":
    if args.test_task.lower() == "seqrec":
        prompt_ids = range(len(all_prompt["seqrec"]))
    elif args.test_task.lower() == "itemsearch":
        prompt_ids = range(len(all_prompt["itemsearch"]))
    elif args.test_task.lower() == "fusionseqrec":
        prompt_ids = range(len(all_prompt["fusionseqrec"]))
else:
    prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]

test_data = load_test_dataset(args)
if local_rank == 0:
    print("evaluate data num:", len(test_data))
# drop_last keeps per-rank batch counts equal so all_gather stays in sync.
ddp_sampler = DistributedSampler(test_data, num_replicas = world_size, rank = local_rank, drop_last = True)
collator = TestCollator(args, tokenizer)
all_items = test_data.get_all_items()
# Constrains decoding so only valid item-index token sequences are produced.
prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
test_loader = DataLoader(
    test_data,
    batch_size = args.test_batch_size,
    collate_fn = collator,
    sampler = ddp_sampler,
    num_workers = 4,
    pin_memory = True
)

model.eval()

metrics = args.metrics.split(",")
all_prompt_results = []

print('prompts:', len(prompt_ids))

with torch.no_grad():
    for prompt_id in prompt_ids:
        if local_rank == 0:
            print("Start prompt: ",prompt_id)
        test_loader.dataset.set_prompt(prompt_id)
        metrics_results = {}
        total = 0

        for step, batch in enumerate(tqdm(test_loader)):
            inputs = batch[0].to(device)
            targets = batch[1]
            bs = len(targets)
            num_beams = args.num_beams

            # Retry generation with fewer beams after CUDA OOM.
            # NOTE(review): num_beams can decrement to 0 if OOM persists
            # (generate would then fail forever), and the generic handler
            # below re-raises a bare RuntimeError, discarding the original
            # exception — prefer `raise` or `raise RuntimeError(...) from e`.
            while True:
                try:
                    output = model.module.generate(
                        input_ids = inputs["input_ids"],
                        attention_mask = inputs["attention_mask"],
                        max_new_tokens = 10,
                        prefix_allowed_tokens_fn = prefix_allowed_tokens,
                        num_beams = num_beams,
                        num_return_sequences = num_beams,
                        output_scores = True,
                        return_dict_in_generate = True,
                        early_stopping = True,
                    )
                    break
                except torch.cuda.OutOfMemoryError as e:
                    print("Out of memory!")
                    num_beams = num_beams -1
                    print("Beam:", num_beams)
                except Exception:
                    raise RuntimeError
            output_ids = output["sequences"]
            scores = output["sequences_scores"]

            # output_ids.shape: torch.Size([20, 101])
            # scores.shape: torch.Size([20])

            output = tokenizer.batch_decode(output_ids, skip_special_tokens = True)
            # output.length: 20
            '''
            Below is an instruction that describes a task.
            Write a response that appropriately completes the request.\n\n
            ### Instruction:\nThe user has interacted with items , ,
            , in chronological order.
            Can you predict the next possible item that the user may expect?\n\n
            ### Response:
            '''

            topk_res = get_topk_results(
                output,
                scores,
                targets,
                num_beams,
                all_items = all_items if args.filter_items else None
            )

            # Gather batch sizes and per-user hit lists from every rank so
            # rank 0 can accumulate global (not per-rank) metrics.
            bs_gather_list = [None for _ in range(world_size)]
            dist.all_gather_object(obj=bs, object_list=bs_gather_list)
            total += sum(bs_gather_list)
            res_gather_list = [None for _ in range(world_size)]
            dist.all_gather_object(obj=topk_res, object_list=res_gather_list)

            if local_rank == 0:
                all_device_topk_res = []
                for ga_res in res_gather_list:
                    all_device_topk_res += ga_res
                batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
                for m, res in batch_metrics_res.items():
                    if m not in metrics_results:
                        metrics_results[m] = res
                    else:
                        metrics_results[m] += res

                # Periodic progress print of running averages.
                if (step + 1) % 50 == 0:
                    temp = {}
                    for m in metrics_results:
                        temp[m] = metrics_results[m] / total
                    print(temp)
            dist.barrier()

        if local_rank == 0:
            # Convert accumulated sums into averages for this prompt.
            for m in metrics_results:
                metrics_results[m] = metrics_results[m] / total
            all_prompt_results.append(metrics_results)
            print("======================================================")
            print("Prompt {} results: ".format(prompt_id), metrics_results)
            print("======================================================")
            print("")
        dist.barrier()
dist.barrier()

if local_rank == 0:
    mean_results = {}
    min_results = {}
    max_results = {}

    # Summarize each metric across all evaluated prompts.
    for m in metrics:
all_res = [_[m] for _ in all_prompt_results] + mean_results[m] = sum(all_res)/len(all_res) + min_results[m] = min(all_res) + max_results[m] = max(all_res) + + print("======================================================") + print("Mean results: ", mean_results) + print("Min results: ", min_results) + print("Max results: ", max_results) + print("======================================================") + + save_data={} + save_data["test_prompt_ids"] = args.test_prompt_ids + save_data["mean_results"] = mean_results + save_data["min_results"] = min_results + save_data["max_results"] = max_results + save_data["all_prompt_results"] = all_prompt_results + + with open(args.results_file, "w") as f: + json.dump(save_data, f, indent = 4) + print("Save file: ", args.results_file) diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..9a61f8bcaaa4a26a2e9438835dd403663b55e71e --- /dev/null +++ b/evaluate.py @@ -0,0 +1,69 @@ +import math + +def get_topk_results(predictions, scores, targets, k, all_items=None): + # target: [''] + results = [] + B = len(targets) + predictions = [_.split("Response:")[-1] for _ in predictions] + predictions = [_.strip().replace(" ","") for _ in predictions] + # prediction: ['', '', ''] + + if all_items is not None: + for i, seq in enumerate(predictions): + if seq not in all_items: + scores[i] = -1000 + + for b in range(B): + batch_seqs = predictions[b * k: (b + 1) * k] + batch_scores = scores[b * k: (b + 1) * k] + + pairs = [(a, b) for a, b in zip(batch_seqs, batch_scores)] + sorted_pairs = sorted(pairs, key=lambda x: x[1], reverse=True) + target_item = targets[b] + one_results = [] + for sorted_pred in sorted_pairs: + if sorted_pred[0] == target_item: + one_results.append(1) + else: + one_results.append(0) + + results.append(one_results) + + # result: [[0, 0, 0]] + return results + +def get_metrics_results(topk_results, metrics): + res = {} + for m in metrics: + if m.lower().startswith("hit"): + 
k = int(m.split("@")[1]) + res[m] = hit_k(topk_results, k) + elif m.lower().startswith("ndcg"): + k = int(m.split("@")[1]) + res[m] = ndcg_k(topk_results, k) + else: + raise NotImplementedError + + return res + + +def ndcg_k(topk_results, k): + + ndcg = 0.0 + for row in topk_results: + res = row[:k] + one_ndcg = 0.0 + for i in range(len(res)): + one_ndcg += res[i] / math.log(i + 2, 2) + ndcg += one_ndcg + return ndcg + + +def hit_k(topk_results, k): + hit = 0.0 + for row in topk_results: + res = row[:k] + if sum(res) > 0: + hit += 1 + return hit + diff --git a/fine-tune.py b/fine-tune.py new file mode 100644 index 0000000000000000000000000000000000000000..87810f265f4c1b0b06134ec469c9f964b79ba5f2 --- /dev/null +++ b/fine-tune.py @@ -0,0 +1,154 @@ +import argparse +import os +import sys +from typing import List + +import torch +import transformers +from peft import PeftModel +from peft import ( + TaskType, + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from utils import * +from collator import Collator + +import argparse +from utils import * +from rq_llama import * + +parser = argparse.ArgumentParser(description = 'rqllama-finetune') +parser = parse_finetune_args(parser) +args = parser.parse_args() + +set_seed(args.seed) +ensure_dir(args.output_dir) + +device_map = "auto" +world_size = int(os.environ.get("WORLD_SIZE", 1)) +ddp = world_size != 1 +local_rank = int(os.environ.get("LOCAL_RANK") or 0) +if local_rank == 0: + print(vars(args)) + +if ddp: + device_map = {"": local_rank} + +train_data, valid_data = load_finetune_datasets(args) + +rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map) +tokenizer = rqllama.tokenizer +# PeftModelForCausalLM +model = rqllama.model +device = rqllama.device + +postfix = '' +new_tokens = [] +new_ids = list(range(args.reindex)) +for i in 
    new_tokens.append(postfix.format(int(i)))
tokenizer.add_tokens(new_tokens)

if local_rank == 0:
    print("token num:", len(rqllama.tokenizer))
    print("data num:", len(train_data))

collator = Collator(args, tokenizer)

# Re-index Embedding
# Initialize the new re-index token embeddings by projecting the scalar ids
# through a (randomly initialized) fp16 linear layer, then append those rows
# to both the frozen original embedding matrix and the trainable PEFT copy.
new_ids = torch.tensor(new_ids, dtype = torch.float16).reshape(-1,1)
re_index_emb = torch.nn.Linear(1, model.config.hidden_size, dtype = torch.float16).to(device)
new_embeddings = re_index_emb(new_ids.to(device))
# PeftModelForCausalLM -> LlamaForCausalLM -> LlamaModel
model.model.model.embed_tokens.original_module.weight.data = torch.cat([model.model.model.embed_tokens.original_module.weight.data, new_embeddings], dim = 0)
model.model.model.embed_tokens.modules_to_save.default.weight.data = torch.cat([model.model.model.embed_tokens.modules_to_save.default.weight.data, new_embeddings], dim = 0)

# New LM-head rows start from a plain normal init (trained via modules_to_save).
new_lm_head = torch.randn(args.reindex, model.config.hidden_size, requires_grad = True).to(device)
# print('new_lm_head:',new_lm_head.requires_grad)
# PeftModelForCausalLM -> LlamaForCausalLM
model.model.lm_head.original_module.weight.data = torch.cat([model.model.lm_head.original_module.weight.data, new_lm_head], dim = 0)
model.model.lm_head.modules_to_save.default.weight.data = torch.cat([model.model.lm_head.modules_to_save.default.weight.data, new_lm_head], dim = 0)

model.config.vocab_size = len(tokenizer)

# print(model.model.model.embed_tokens.original_module.weight.shape)
# print(len(tokenizer))

model.train()

if local_rank == 0:
    model.print_trainable_parameters()

trainer = transformers.Trainer(
    model = model,
    train_dataset = train_data,
    eval_dataset = valid_data,
    args = transformers.TrainingArguments(
        seed = args.seed,
        per_device_train_batch_size = args.per_device_batch_size,
        per_device_eval_batch_size = args.per_device_batch_size,
        gradient_accumulation_steps = args.gradient_accumulation_steps,
        warmup_ratio = args.warmup_ratio,
        num_train_epochs = args.epochs,
        learning_rate = args.learning_rate,
        weight_decay = args.weight_decay,
        lr_scheduler_type = args.lr_scheduler_type,
        fp16 = args.fp16,
        bf16 = args.bf16,
        logging_steps = args.logging_step,
        optim = args.optim,
        gradient_checkpointing = True,
        evaluation_strategy = args.save_and_eval_strategy,
        save_strategy = args.save_and_eval_strategy,
        eval_steps = args.save_and_eval_steps,
        save_steps = args.save_and_eval_steps,
        output_dir = args.output_dir,
        save_total_limit = 50,
        load_best_model_at_end = True,
        deepspeed = args.deepspeed,
        ddp_find_unused_parameters = False if ddp else None,
        report_to = None,
        # With epoch strategy, skip eval for the first epoch; with steps,
        # skip the first 2000 steps.
        eval_delay = 1 if args.save_and_eval_strategy=="epoch" else 2000,
        dataloader_num_workers = args.dataloader_num_workers,
        dataloader_prefetch_factor = args.dataloader_prefetch_factor,
        remove_unused_columns = args.remove_unused_columns,
    ),
    tokenizer = tokenizer,
    data_collator = collator,
)
model.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
    model = torch.compile(model)

trainer.train(resume_from_checkpoint = args.resume_from_checkpoint)

trainer.save_state()
trainer.save_model(output_dir = args.output_dir)

if local_rank == 0:
    print('rqllama fine-tune finished.')

    # Completion-notification email.
    # SECURITY NOTE(review): SMTP credentials (account + app password) are
    # hard-coded and committed to source control — rotate the password and
    # move these to environment variables before any further use.
    import smtplib
    from email.mime.text import MIMEText
    mail_host = 'smtp.qq.com'
    mail_code = 'ouzplpngooqndjcb'
    sender = '1849334588@qq.com'
    receiver = 'esperanto1949@foxmail.com'

    task = '[v39: finetune twin-tower]'
    message = MIMEText('Task {task} Finished'.format(task = task), 'plain', 'utf-8')
    message['Subject'] = 'Auto Email'
    message['From'] = sender
    message['To'] = receiver

    server = smtplib.SMTP_SSL("smtp.qq.com", 465)
    server.login(sender, mail_code)
    server.sendmail(sender, receiver, message.as_string())

    server.quit()
# ---- finetune.py ----
import argparse
import os

import sys
from typing import List

import torch
import transformers

from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import Collator

def train(args):
    """Fine-tune LLaMA on the LC-Rec alignment tasks.

    Extends the tokenizer with the dataset's new item-index tokens, resizes
    the model embeddings to match, and runs a HuggingFace Trainer with the
    settings carried in *args*.  Saves tokenizer/config/model to
    args.output_dir.
    """
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))

    if ddp:
        device_map = {"": local_rank}

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length = args.model_max_length,
        padding_side="right",
    )
    # LLaMA has no pad token; reuse id 0 for padding.
    tokenizer.pad_token_id = 0
    gradient_checkpointing = True

    train_data, valid_data = load_datasets(args)
    # Register the item-index tokens produced by the RQ-VAE indexing step.
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("add {} new token.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)


    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        # torch_dtype=torch.float16,
        device_map=device_map,
    )
    # Grow the embedding/LM-head tables to cover the added tokens.
    model.resize_token_embeddings(len(tokenizer))


    if not ddp and torch.cuda.device_count() > 1:
        # Single-process multi-GPU: let HF shard the model instead of DDP.
        model.is_parallelizable = True
        model.model_parallel = True


    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=gradient_checkpointing,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay= 1 if args.save_and_eval_strategy=="epoch" else 2000,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    model.config.use_cache = False


    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)




if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)

    args = parser.parse_args()

    train(args)

# ---- generate_embeddings.py ----
# Dump mean-pooled LLaMA hidden states for every item / user text to TSV.
import os
import collections
import json
import logging
import argparse
import numpy as np
import pandas as pd
import torch
from time import time
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader
from rq_llama import *

def parse_args():
    """CLI: checkpoint path, item/user output paths, and target device id."""
    parser = argparse.ArgumentParser(description = "Index")
    parser.add_argument("--ckpt_path", type = str, default = "", help = "")
    parser.add_argument("--item_save_path", type = str, default = "", help = "")
    parser.add_argument("--user_save_path", type = str, default = "", help = "")
    parser.add_argument("--device_map", type = str, default = "1", help = "gpu or cpu")
default = "1", help = "gpu or cpu") + return parser.parse_args() + +args = parse_args() +print(args) +device_map = {'': int(args.device_map)} +MODEL = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map) +MODEL.eval() +device = MODEL.device +llama = MODEL.model.get_decoder() +tokenizer = MODEL.tokenizer +item_texts = MODEL.item_texts +user_texts = MODEL.user_texts + +all_idx = [] +all_embeddings = [] +with torch.no_grad(): + for idx, text in tqdm(item_texts.items()): + item_text = text['title'] + ' ' + text['description'] + item_ids = tokenizer(item_text, return_tensors = 'pt', padding = True, truncation = True).to(device) + item_emb = llama(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim = 1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + + all_idx.append(idx) + all_embeddings.append(item_emb.detach().cpu().numpy().flatten().tolist()) + +results = { + 'id': all_idx, + 'emb': [] +} + +for emb in tqdm(all_embeddings): + str_emb = '' + for e in emb: + str_emb = str_emb + str(e) + ' ' + results['emb'].append(str_emb[:-1]) + +df = pd.DataFrame(results) +df.to_csv(args.item_save_path, sep = '\t', header = 0, index = False) + +all_idx = [] +all_embeddings = [] +with torch.no_grad(): + for idx, text in tqdm(user_texts.items()): + user_text = ' '.join(text) + user_ids = tokenizer(user_text, return_tensors = 'pt', padding = True, truncation = True).to(device) + user_emb = llama(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim = 1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + + all_idx.append(idx) + all_embeddings.append(user_emb.detach().cpu().numpy().flatten().tolist()) + +results = { + 'id': all_idx, + 'emb': [] 
+} + +for emb in tqdm(all_embeddings): + str_emb = '' + for e in emb: + str_emb = str_emb + str(e) + ' ' + results['emb'].append(str_emb[:-1]) + +df = pd.DataFrame(results) +df.to_csv(args.user_save_path, sep = '\t', header = 0, index = False) \ No newline at end of file diff --git a/generate_indices.py b/generate_indices.py new file mode 100644 index 0000000000000000000000000000000000000000..5c982c0aa56c51bea99317571ea4976af836f4f5 --- /dev/null +++ b/generate_indices.py @@ -0,0 +1,150 @@ +import os +import collections +import json +import logging +import argparse +import numpy as np +import pandas as pd +import torch +from time import time +from torch import optim +from tqdm import tqdm +from torch.utils.data import DataLoader + +from rq_llama import * +from index.datasets import EmbDataset + +def if_collided(all_indices_str): + tot_item = len(all_indices_str) + tot_indice = len(set(all_indices_str.tolist())) + return tot_item == tot_indice + +def get_indices_count(all_indices_str): + indices_count = collections.defaultdict(int) + for index in all_indices_str: + indices_count[index] += 1 + return indices_count + +def get_collision_item(all_indices_str): + index2id = {} + for i, index in enumerate(all_indices_str): + if index not in index2id: + index2id[index] = [] + index2id[index].append(i) + collision_item_groups = [] + for index in index2id: + if len(index2id[index]) > 1: + collision_item_groups.append(index2id[index]) + return collision_item_groups + +def parse_args(): + parser = argparse.ArgumentParser(description = "Index") + parser.add_argument("--ckpt_path", type = str, default = "", help = "") + parser.add_argument("--item_data_path", type = str, default = "", help = "") + parser.add_argument("--user_data_path", type = str, default = "", help = "") + parser.add_argument("--save_path", type = str, default = "", help = "") + parser.add_argument("--device_map", type = str, default = "1", help = "gpu or cpu") + return parser.parse_args() + +args = 
print(args)

device_map = {'': int(args.device_map)}
MODEL = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map)
MODEL.eval()
device = MODEL.device
# NOTE(review): postfix looks extraction-mangled (empty template means every
# disambiguation suffix renders as '') — likely a token such as '<r_{}>'
# upstream; confirm against the original repository.
postfix = ''

# ---- item indices: quantize each embedding, appending a per-collision
# counter token so conflicting codes become unique.
data = EmbDataset(args.item_data_path)
data_loader = DataLoader(data, num_workers = 4, batch_size = 64, shuffle = False, pin_memory = True)
rqvae = MODEL.item_rqvae
prefix = MODEL.prefix

index_table = {}
all_indices = []
all_indices_str = []
with torch.no_grad():
    for x in tqdm(data_loader):
        indices = rqvae.get_indices(x.to(device), False)
        indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        for index in indices:
            code = []
            for i, ind in enumerate(index):
                # Format each quantizer level's id with its level-specific
                # token template.
                code.append(prefix[i].format(int(ind)))

            # Count how many items already produced this exact code and
            # append that count as an extra disambiguation token.
            if str(code) in index_table:
                index_table[str(code)] += 1
            else:
                index_table[str(code)] = 0
            code.append(postfix.format(index_table[str(code)]))

            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)
all_indices_str = np.array(all_indices_str)

print("All indices number: ", len(all_indices))
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
print('Re-index number:', max(index_table.values()))

# Item position -> list of code tokens.
all_indices_dict = {}
for item, indices in enumerate(all_indices.tolist()):
    all_indices_dict[item] = list(indices)

reindex_dict = {'reindex': max(index_table.values())}

json_path = os.path.join(args.save_path,'indices.item.json')
with open(json_path, 'w',encoding = 'utf-8') as f:
    json.dump(all_indices_dict, f)

reindex_path = os.path.join(args.save_path,'reindex.item.json')
with open(reindex_path, 'w',encoding = 'utf-8') as f:
    json.dump(reindex_dict, f)

# ---- user indices: same quantization, but without collision suffixes
# (user codes are allowed to collide; see commented-out lines).
data = EmbDataset(args.user_data_path)
data_loader = DataLoader(data, num_workers = 4, batch_size = 64, shuffle = False, pin_memory = True)
rqvae = MODEL.user_rqvae
prefix = MODEL.user_prefix

# index_table = {}
all_indices = []
all_indices_str = []
with torch.no_grad():
    for x in tqdm(data_loader):
        indices = rqvae.get_indices(x.to(device), False)
        indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        for index in indices:
            code = []
            for i, ind in enumerate(index):
                code.append(prefix[i].format(int(ind)))

            # if str(code) in index_table:
            #     index_table[str(code)] += 1
            # else:
            #     index_table[str(code)] = 0
            # code.append(postfix.format(index_table[str(code)]))

            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)
all_indices_str = np.array(all_indices_str)

print("All indices number: ", len(all_indices))
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
# print('Re-index number:', max(index_table.values()))

all_indices_dict = {}
for item, indices in enumerate(all_indices.tolist()):
    all_indices_dict[item] = list(indices)

# reindex_dict = {'reindex': max(index_table.values())}

json_path = os.path.join(args.save_path,'indices.user.json')
with open(json_path, 'w',encoding = 'utf-8') as f:
    json.dump(all_indices_dict, f)

# reindex_path = os.path.join(args.save_path,'reindex.user.json')
# with open(reindex_path, 'w',encoding = 'utf-8') as f:
#     json.dump(reindex_dict, f)

# ---- generate_random_indices.py ----
# Baseline: assign items uniformly random codes (instead of RQ-VAE codes).
import os
import collections
import json
import logging
import argparse
import numpy as np
import pandas as pd
import torch
from time import time
from torch import optim
from tqdm import tqdm
import torch.utils.data as data
from torch.utils.data import DataLoader
from index.models.rqvae import RQVAE
# from rq_llama import *
# from index.datasets import EmbDataset
import random

class NpyDataset(data.Dataset):
def __init__(self, data_path): + self.data_path = data_path + self.embeddings = np.load(data_path) + self.dim = self.embeddings.shape[-1] + + def __getitem__(self, index): + emb = self.embeddings[index] + tensor_emb = torch.FloatTensor(emb) + return tensor_emb + + def __len__(self): + return len(self.embeddings) + +def if_collided(all_indices_str): + tot_item = len(all_indices_str) + tot_indice = len(set(all_indices_str.tolist())) + return tot_item == tot_indice + +def get_indices_count(all_indices_str): + indices_count = collections.defaultdict(int) + for index in all_indices_str: + indices_count[index] += 1 + return indices_count + +def get_collision_item(all_indices_str): + index2id = {} + for i, index in enumerate(all_indices_str): + if index not in index2id: + index2id[index] = [] + index2id[index].append(i) + collision_item_groups = [] + for index in index2id: + if len(index2id[index]) > 1: + collision_item_groups.append(index2id[index]) + return collision_item_groups + +def parse_args(): + parser = argparse.ArgumentParser(description = "Index") + parser.add_argument("--item_model_path", type = str, default = "", help = "") + parser.add_argument("--item_data_path", type = str, default = "", help = "") + parser.add_argument("--user_model_path", type = str, default = "", help = "") + parser.add_argument("--user_data_path", type = str, default = "", help = "") + # parser.add_argument("--save_path", type = str, default = "", help = "") + parser.add_argument("--device", type = str, default = "cuda:0", help = "gpu or cpu") + return parser.parse_args() + +generate_args = parse_args() +print(generate_args) + +device = torch.device(generate_args.device) + +# generate item index +ckpt = torch.load(os.path.join(generate_args.item_model_path, 'best_collision_model.pth'), map_location = torch.device('cpu')) +args = ckpt['args'] +state_dict = ckpt['state_dict'] + +data = NpyDataset(generate_args.item_data_path) +data_loader = DataLoader(data, num_workers = 
args.num_workers, batch_size = 64, shuffle = False, pin_memory = True) +# model = RQVAE( +# in_dim = data.dim, +# num_emb_list = args.num_emb_list, +# e_dim = args.e_dim, +# layers = args.layers, +# dropout_prob = args.dropout_prob, +# bn = args.bn, +# loss_type = args.loss_type, +# quant_loss_weight = args.quant_loss_weight, +# kmeans_init = args.kmeans_init, +# kmeans_iters = args.kmeans_iters, +# sk_epsilons = args.sk_epsilons, +# sk_iters = args.sk_iters, +# ) +# model.load_state_dict(state_dict) +# model = model.to(device) +# model.eval() +# print(model) + +prefix = ["","","","",""] +postfix = "" + +index_table = {} +all_indices = [] +all_indices_str = [] +with torch.no_grad(): + for x in tqdm(data_loader): + # indices = model.get_indices(x.to(device), False) + # indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + + indices = np.random.randint(0, 256, size = (64, 4), dtype = int) + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + if str(code) in index_table: + index_table[str(code)] += 1 + else: + index_table[str(code)] = 0 + code.append(postfix.format(index_table[str(code)])) + + all_indices.append(code) + all_indices_str.append(str(code)) + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +print("All indices number: ", len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) +print('Re-index number:', max(index_table.values())) + +all_indices_dict = {} +for item, indices in enumerate(all_indices.tolist()): + all_indices_dict[item] = list(indices) +reindex_dict = {'reindex': max(index_table.values())} + +item_index_path = os.path.join(generate_args.item_model_path, 'indices.random.item.json') +with open(item_index_path, 'w', encoding = 'utf-8') as f: + json.dump(all_indices_dict, f) + +item_reindex_path = os.path.join(generate_args.item_model_path, 'reindex.random.item.json') +with open(item_reindex_path, 
'w', encoding = 'utf-8') as f: + json.dump(reindex_dict, f) + +# generate user index +ckpt = torch.load(os.path.join(generate_args.user_model_path, 'best_collision_model.pth'), map_location = torch.device('cpu')) +args = ckpt['args'] +state_dict = ckpt['state_dict'] + +data = NpyDataset(generate_args.user_data_path) +data_loader = DataLoader(data, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True) +# model = RQVAE( +# in_dim = data.dim, +# num_emb_list = args.num_emb_list, +# e_dim = args.e_dim, +# layers = args.layers, +# dropout_prob = args.dropout_prob, +# bn = args.bn, +# loss_type = args.loss_type, +# quant_loss_weight = args.quant_loss_weight, +# kmeans_init = args.kmeans_init, +# kmeans_iters = args.kmeans_iters, +# sk_epsilons = args.sk_epsilons, +# sk_iters = args.sk_iters, +# ) +# model.load_state_dict(state_dict) +# model = model.to(device) +# model.eval() +# print(model) + +prefix = ['','','','',''] + +all_indices = [] +all_indices_str = [] +with torch.no_grad(): + for x in tqdm(data_loader): + # indices = rqvae.get_indices(x.to(device), False) + # indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + indices = np.random.randint(0, 256, size = (64, 4), dtype = int) + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices.append(code) + all_indices_str.append(str(code)) + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +print("All indices number: ", len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) + +all_indices_dict = {} +for item, indices in enumerate(all_indices.tolist()): + all_indices_dict[item] = list(indices) + +json_path = os.path.join(generate_args.user_model_path, 'indices.random.user.json') +with open(json_path, 'w', encoding = 'utf-8') as f: + json.dump(all_indices_dict, f) \ No newline at end of file diff --git a/generate_static_indices.py 
b/generate_static_indices.py new file mode 100644 index 0000000000000000000000000000000000000000..1751db4c318ce8a64a73f3830453252046ecdfbe --- /dev/null +++ b/generate_static_indices.py @@ -0,0 +1,193 @@ +import os +import collections +import json +import logging +import argparse +import numpy as np +import pandas as pd +import torch +from time import time +from torch import optim +from tqdm import tqdm +from torch.utils.data import DataLoader +from index.models.rqvae import RQVAE +# from rq_llama import * +# from index.datasets import EmbDataset + +class NpyDataset(data.Dataset): + def __init__(self, data_path): + self.data_path = data_path + self.embeddings = np.load(data_path) + self.dim = self.embeddings.shape[-1] + + def __getitem__(self, index): + emb = self.embeddings[index] + tensor_emb = torch.FloatTensor(emb) + return tensor_emb + + def __len__(self): + return len(self.embeddings) + +def if_collided(all_indices_str): + tot_item = len(all_indices_str) + tot_indice = len(set(all_indices_str.tolist())) + return tot_item == tot_indice + +def get_indices_count(all_indices_str): + indices_count = collections.defaultdict(int) + for index in all_indices_str: + indices_count[index] += 1 + return indices_count + +def get_collision_item(all_indices_str): + index2id = {} + for i, index in enumerate(all_indices_str): + if index not in index2id: + index2id[index] = [] + index2id[index].append(i) + collision_item_groups = [] + for index in index2id: + if len(index2id[index]) > 1: + collision_item_groups.append(index2id[index]) + return collision_item_groups + +def parse_args(): + parser = argparse.ArgumentParser(description = "Index") + parser.add_argument("--item_model_path", type = str, default = "", help = "") + parser.add_argument("--item_data_path", type = str, default = "", help = "") + parser.add_argument("--user_model_path", type = str, default = "", help = "") + parser.add_argument("--user_data_path", type = str, default = "", help = "") + # 
parser.add_argument("--save_path", type = str, default = "", help = "") + parser.add_argument("--device", type = str, default = "cuda:0", help = "gpu or cpu") + return parser.parse_args() + +generate_args = parse_args() +print(generate_args) + +device = torch.device(generate_args.device) + +# generate item index +ckpt = torch.load(generate_args.item_model_path, map_location = torch.device('cpu')) +args = ckpt['args'] +state_dict = ckpt['state_dict'] + +data = NpyDataset(generate_args.item_data_path) +data_loader = DataLoader(data, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True) +model = RQVAE( + in_dim = data.dim, + num_emb_list = args.num_emb_list, + e_dim = args.e_dim, + layers = args.layers, + dropout_prob = args.dropout_prob, + bn = args.bn, + loss_type = args.loss_type, + quant_loss_weight = args.quant_loss_weight, + kmeans_init = args.kmeans_init, + kmeans_iters = args.kmeans_iters, + sk_epsilons = args.sk_epsilons, + sk_iters = args.sk_iters, +) +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +# print(model) + +prefix = ["","","","",""] +postfix = "" + +index_table = {} +all_indices = [] +all_indices_str = [] +with torch.no_grad(): + for x in tqdm(data_loader): + indices = model.get_indices(x.to(device), False) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + if str(code) in index_table: + index_table[str(code)] += 1 + else: + index_table[str(code)] = 0 + code.append(postfix.format(index_table[str(code)])) + + all_indices.append(code) + all_indices_str.append(str(code)) + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +print("All indices number: ", len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) +print('Re-index number:', max(index_table.values())) + +all_indices_dict = {} +for 
item, indices in enumerate(all_indices.tolist()): + all_indices_dict[item] = list(indices) +reindex_dict = {'reindex': max(index_table.values())} + +item_index_path = os.path.join(generate_args.item_model_path, 'indices.item.json') +with open(json_path, 'w', encoding = 'utf-8') as f: + json.dump(all_indices_dict, f) + +item_reindex_path = os.path.join(generate_args.item_model_path, 'reindex.item.json') +with open(reindex_path, 'w', encoding = 'utf-8') as f: + json.dump(reindex_dict, f) + +# generate user index +ckpt = torch.load(generate_args.user_model_path, map_location = torch.device('cpu')) +args = ckpt['args'] +state_dict = ckpt['state_dict'] + +data = NpyDataset(generate_args.user_data_path) +data_loader = DataLoader(data, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True) +model = RQVAE( + in_dim = data.dim, + num_emb_list = args.num_emb_list, + e_dim = args.e_dim, + layers = args.layers, + dropout_prob = args.dropout_prob, + bn = args.bn, + loss_type = args.loss_type, + quant_loss_weight = args.quant_loss_weight, + kmeans_init = args.kmeans_init, + kmeans_iters = args.kmeans_iters, + sk_epsilons = args.sk_epsilons, + sk_iters = args.sk_iters, +) +model.load_state_dict(state_dict) +model = model.to(device) +model.eval() +# print(model) + +prefix = ['','','','',''] + +all_indices = [] +all_indices_str = [] +with torch.no_grad(): + for x in tqdm(data_loader): + indices = rqvae.get_indices(x.to(device), False) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices.append(code) + all_indices_str.append(str(code)) + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +print("All indices number: ", len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) + +all_indices_dict = {} +for item, indices in 
enumerate(all_indices.tolist()): + all_indices_dict[item] = list(indices) + +json_path = os.path.join(generate_args.user_model_path, 'indices.user.json') +with open(json_path, 'w', encoding = 'utf-8') as f: + json.dump(all_indices_dict, f) \ No newline at end of file diff --git a/index/datasets.py b/index/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..14c3ddec98d25b9c0f9d8a91eb5fa376d25a6be5 --- /dev/null +++ b/index/datasets.py @@ -0,0 +1,27 @@ +import numpy as np +import torch +import torch.utils.data as data +import pandas as pd +from tqdm import tqdm + +class EmbDataset(data.Dataset): + def __init__(self,data_path): + self.data_path = data_path + names = ['emb'] + usecols = [1] + tsv_data = pd.read_csv(data_path, sep = '\t',usecols = usecols, names = names, quotechar = None, quoting = 3) + features = tsv_data['emb'].values.tolist() + num_data = len(features) + for i in tqdm(range(num_data)): + features[i] = [float(s) for s in features[i].split(' ')] + self.embeddings = np.array(features, dtype = np.float16) + assert self.embeddings.shape[0] == num_data + self.dim = self.embeddings.shape[-1] + + def __getitem__(self, index): + emb = self.embeddings[index] + tensor_emb = torch.tensor(emb, dtype = torch.float16) + return tensor_emb + + def __len__(self): + return len(self.embeddings) diff --git a/index/generate_indices.py b/index/generate_indices.py new file mode 100644 index 0000000000000000000000000000000000000000..bd2220e8d7c58b933c1f4e1ceaa448f185fb43a4 --- /dev/null +++ b/index/generate_indices.py @@ -0,0 +1,151 @@ +import collections +import json +import logging +import argparse + +import numpy as np +import torch +from time import time +from torch import optim +from tqdm import tqdm + +from torch.utils.data import DataLoader + +from datasets import EmbDataset +from models.rqvae import RQVAE + +import os + +def check_collision(all_indices_str): + tot_item = len(all_indices_str) + tot_indice = 
len(set(all_indices_str.tolist())) + return tot_item==tot_indice + +def get_indices_count(all_indices_str): + indices_count = collections.defaultdict(int) + for index in all_indices_str: + indices_count[index] += 1 + return indices_count + +def get_collision_item(all_indices_str): + index2id = {} + for i, index in enumerate(all_indices_str): + if index not in index2id: + index2id[index] = [] + index2id[index].append(i) + + collision_item_groups = [] + + for index in index2id: + if len(index2id[index]) > 1: + collision_item_groups.append(index2id[index]) + + return collision_item_groups + +def parse_args(): + parser = argparse.ArgumentParser(description = "Index") + + parser.add_argument("--data_path", type = str, default = "", help = "Infer data path.") + parser.add_argument("--ckpt_path", type=str, default="", help="model checkpoint for infer") + parser.add_argument("--id_save_path", type=str, default="", help="output directory for id result") + parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu") + + return parser.parse_args() + +# dataset = "Games" +# ckpt_path = "/zhengbowen/rqvae_ckpt/xxxx" +# output_dir = f"/zhengbowen/data/{dataset}/" +# output_file = f"{dataset}.index.json" +# output_file = os.path.join(output_dir,output_file) + +infer_args = parse_args() +print('infer_args:', infer_args) +device = torch.device(infer_args.device) +output_file = infer_args.id_save_path +data = EmbDataset(infer_args.data_path) + +ckpt = torch.load(infer_args.ckpt_path, map_location = torch.device('cpu')) +args = ckpt["args"] +state_dict = ckpt["state_dict"] + +model = RQVAE(in_dim=data.dim, + num_emb_list=args.num_emb_list, + e_dim=args.e_dim, + layers=args.layers, + dropout_prob=args.dropout_prob, + bn=args.bn, + loss_type=args.loss_type, + quant_loss_weight=args.quant_loss_weight, + kmeans_init=args.kmeans_init, + kmeans_iters=args.kmeans_iters, + sk_epsilons=args.sk_epsilons, + sk_iters=args.sk_iters, + ) + +model.load_state_dict(state_dict) 
+model = model.to(device) +model.eval() +print(model) + +data_loader = DataLoader(data, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True) + +all_indices = [] +all_indices_str = [] +prefix = ["","","","",""] + +for d in tqdm(data_loader): + d = d.to(device) + indices = model.get_indices(d,use_sk = False) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for index in indices: + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices.append(code) + all_indices_str.append(str(code)) + +all_indices = np.array(all_indices) +all_indices_str = np.array(all_indices_str) + +for vq in model.rq.vq_layers[:-1]: + vq.sk_epsilon = 0.0 +if model.rq.vq_layers[-1].sk_epsilon == 0.0: + model.rq.vq_layers[-1].sk_epsilon = 0.003 + +tt = 0 +#There are often duplicate items in the dataset, and we no longer differentiate them +while True: + if tt >= 20 or check_collision(all_indices_str): + break + + collision_item_groups = get_collision_item(all_indices_str) + # print(collision_item_groups) + print(len(collision_item_groups)) + for collision_items in collision_item_groups: + d = data[collision_items].to(device) + + indices = model.get_indices(d, use_sk= True) + indices = indices.view(-1, indices.shape[-1]).cpu().numpy() + for item, index in zip(collision_items, indices): + code = [] + for i, ind in enumerate(index): + code.append(prefix[i].format(int(ind))) + + all_indices[item] = code + all_indices_str[item] = str(code) + tt += 1 + +print("All indices number: ", len(all_indices)) +print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values())) + +tot_item = len(all_indices_str) +tot_indice = len(set(all_indices_str.tolist())) +print("Collision Rate", (tot_item - tot_indice) / tot_item) + +all_indices_dict = {} +for item, indices in enumerate(all_indices.tolist()): + all_indices_dict[item] = list(indices) + +with open(output_file, 'w') as fp: + json.dump(all_indices_dict, 
fp) diff --git a/index/main.py b/index/main.py new file mode 100644 index 0000000000000000000000000000000000000000..d6718ff95140ea304fd3302c34524ecb6fb6e8a5 --- /dev/null +++ b/index/main.py @@ -0,0 +1,87 @@ +import argparse +import random +import torch +import numpy as np +from time import time +import logging + +from torch.utils.data import DataLoader + +from datasets import EmbDataset +from models.rqvae import RQVAE +from trainer import Trainer + +def parse_args(): + parser = argparse.ArgumentParser(description="Index") + + parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') + parser.add_argument('--epochs', type=int, default=5000, help='number of epochs') + parser.add_argument('--batch_size', type=int, default=1024, help='batch size') + parser.add_argument('--num_workers', type=int, default=4, ) + parser.add_argument('--eval_step', type=int, default=50, help='eval step') + parser.add_argument('--learner', type=str, default="AdamW", help='optimizer') + parser.add_argument("--data_path", type=str, + default="../data/Games/Games.emb-llama-td.npy", + help="Input data path.") + + parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight') + parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio") + parser.add_argument("--bn", type=bool, default=False, help="use bn or not") + parser.add_argument("--loss_type", type=str, default="mse", help="loss_type") + parser.add_argument("--kmeans_init", type=bool, default=True, help="use kmeans_init or not") + parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters") + parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons") + parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters") + + parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu") + + parser.add_argument('--num_emb_list', type=int, nargs='+', 
default=[256,256,256], help='emb num of every vq') + parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size') + parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantion loss weight') + parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer') + + parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model") + + return parser.parse_args() + + +if __name__ == '__main__': + """fix the random seed""" + seed = 2023 + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + args = parse_args() + print(args) + + logging.basicConfig(level=logging.DEBUG) + + """build dataset""" + data = EmbDataset(args.data_path) + model = RQVAE(in_dim=data.dim, + num_emb_list=args.num_emb_list, + e_dim=args.e_dim, + layers=args.layers, + dropout_prob=args.dropout_prob, + bn=args.bn, + loss_type=args.loss_type, + quant_loss_weight=args.quant_loss_weight, + kmeans_init=args.kmeans_init, + kmeans_iters=args.kmeans_iters, + sk_epsilons=args.sk_epsilons, + sk_iters=args.sk_iters, + ) + print(model) + data_loader = DataLoader(data,num_workers=args.num_workers, + batch_size=args.batch_size, shuffle=True, + pin_memory=True) + trainer = Trainer(args,model) + best_loss, best_collision_rate = trainer.fit(data_loader) + + print("Best Loss",best_loss) + print("Best Collision Rate", best_collision_rate) + diff --git a/index/models/layers.py b/index/models/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..47836c9d1f3cc7ce13281fcd575abc64a06f1930 --- /dev/null +++ b/index/models/layers.py @@ -0,0 +1,106 @@ +import torch +import torch.nn as nn +from torch.nn.init import xavier_normal_ +from sklearn.cluster import KMeans + + +class MLPLayers(nn.Module): + + def __init__( + self, layers, 
dropout=0.0, activation="relu", bn=False + ): + super(MLPLayers, self).__init__() + self.layers = layers + self.dropout = dropout + self.activation = activation + self.use_bn = bn + + mlp_modules = [] + for idx, (input_size, output_size) in enumerate( + zip(self.layers[:-1], self.layers[1:]) + ): + mlp_modules.append(nn.Dropout(p=self.dropout)) + mlp_modules.append(nn.Linear(input_size, output_size)) + if self.use_bn: + mlp_modules.append(nn.BatchNorm1d(num_features=output_size)) + activation_func = activation_layer(self.activation, output_size) + if activation_func is not None and idx != (len(self.layers)-2): + mlp_modules.append(activation_func) + + self.mlp_layers = nn.Sequential(*mlp_modules) + self.apply(self.init_weights) + + def init_weights(self, module): + # We just initialize the module with normal distribution as the paper said + if isinstance(module, nn.Linear): + xavier_normal_(module.weight.data) + if module.bias is not None: + module.bias.data.fill_(0.0) + + def forward(self, input_feature): + return self.mlp_layers(input_feature) + +def activation_layer(activation_name="relu", emb_dim=None): + + if activation_name is None: + activation = None + elif isinstance(activation_name, str): + if activation_name.lower() == "sigmoid": + activation = nn.Sigmoid() + elif activation_name.lower() == "tanh": + activation = nn.Tanh() + elif activation_name.lower() == "relu": + activation = nn.ReLU() + elif activation_name.lower() == "leakyrelu": + activation = nn.LeakyReLU() + elif activation_name.lower() == "none": + activation = None + elif issubclass(activation_name, nn.Module): + activation = activation_name() + else: + raise NotImplementedError( + "activation function {} is not implemented".format(activation_name) + ) + + return activation + +def kmeans( + samples, + num_clusters, + num_iters = 10, +): + B, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device + x = samples.cpu().detach().numpy() + + cluster = 
KMeans(n_clusters = num_clusters, max_iter = num_iters).fit(x) + + centers = cluster.cluster_centers_ + tensor_centers = torch.from_numpy(centers).to(device) + + return tensor_centers + + +@torch.no_grad() +def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations): + Q = torch.exp(- distances / epsilon) + + B = Q.shape[0] # number of samples to assign + K = Q.shape[1] # how many centroids per block (usually set to 256) + + # make the matrix sums to 1 + sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True) + Q /= sum_Q + # print(Q.sum()) + for it in range(sinkhorn_iterations): + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=1, keepdim=True) + Q /= B + + # normalize each row: total weight per prototype must be 1/K + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= K + + + Q *= B # the colomns must sum to 1 so that Q is an assignment + return Q \ No newline at end of file diff --git a/index/models/rq.py b/index/models/rq.py new file mode 100644 index 0000000000000000000000000000000000000000..bbb7a32734759bca6b5a9867855f7cf5d445048c --- /dev/null +++ b/index/models/rq.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn + +from .vq import VectorQuantizer + + +class ResidualVectorQuantizer(nn.Module): + """ References: + SoundStream: An End-to-End Neural Audio Codec + https://arxiv.org/pdf/2107.03312.pdf + """ + + def __init__(self, n_e_list, e_dim, sk_epsilons, + kmeans_init = False, kmeans_iters = 100, sk_iters=100,): + super().__init__() + self.n_e_list = n_e_list + self.e_dim = e_dim + self.num_quantizers = len(n_e_list) + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilons = sk_epsilons + self.sk_iters = sk_iters + self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim, + kmeans_init = self.kmeans_init, + kmeans_iters = self.kmeans_iters, + sk_epsilon=sk_epsilon, + sk_iters=sk_iters) + for n_e, sk_epsilon in zip(n_e_list,sk_epsilons) ]) + + def get_codebook(self): + all_codebook = 
[] + for quantizer in self.vq_layers: + codebook = quantizer.get_codebook() + all_codebook.append(codebook) + return torch.stack(all_codebook) + + def forward(self, x, use_sk=True): + all_losses = [] + all_indices = [] + + x_q = 0 + residual = x + for quantizer in self.vq_layers: + x_res, loss, indices = quantizer(residual, use_sk=use_sk) + residual = residual - x_res + x_q = x_q + x_res + + all_losses.append(loss) + all_indices.append(indices) + + mean_losses = torch.stack(all_losses).mean() + all_indices = torch.stack(all_indices, dim=-1) + + return x_q, mean_losses, all_indices \ No newline at end of file diff --git a/index/models/rqvae.py b/index/models/rqvae.py new file mode 100644 index 0000000000000000000000000000000000000000..74b58ca40b986bd91620c5fd50c9ce51cafe8788 --- /dev/null +++ b/index/models/rqvae.py @@ -0,0 +1,84 @@ +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from .layers import MLPLayers +from .rq import ResidualVectorQuantizer + + +class RQVAE(nn.Module): + def __init__(self, + in_dim=768, + # num_emb_list=[256,256,256,256], + num_emb_list=None, + e_dim=64, + # layers=[512,256,128], + layers=None, + dropout_prob=0.0, + bn=False, + loss_type="mse", + quant_loss_weight=1.0, + kmeans_init=False, + kmeans_iters=100, + # sk_epsilons=[0,0,0.003,0.01]], + sk_epsilons=None, + sk_iters=100, + ): + super(RQVAE, self).__init__() + + self.in_dim = in_dim + self.num_emb_list = num_emb_list + self.e_dim = e_dim + + self.layers = layers + self.dropout_prob = dropout_prob + self.bn = bn + self.loss_type = loss_type + self.quant_loss_weight=quant_loss_weight + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilons = sk_epsilons + self.sk_iters = sk_iters + + self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim] + self.encoder = MLPLayers(layers=self.encode_layer_dims, + dropout=self.dropout_prob,bn=self.bn) + + self.rq = ResidualVectorQuantizer(num_emb_list, e_dim, + 
kmeans_init = self.kmeans_init, + kmeans_iters = self.kmeans_iters, + sk_epsilons=self.sk_epsilons, + sk_iters=self.sk_iters,) + + self.decode_layer_dims = self.encode_layer_dims[::-1] + self.decoder = MLPLayers(layers=self.decode_layer_dims, + dropout=self.dropout_prob,bn=self.bn) + + def forward(self, x, use_sk=True): + # print('x.shape:',x.shape) + x = self.encoder(x) + x_q, rq_loss, indices = self.rq(x,use_sk=use_sk) + out = self.decoder(x_q) + # print('out.shape:',out.shape) + + return out, rq_loss, indices + + @torch.no_grad() + def get_indices(self, xs, use_sk=False): + x_e = self.encoder(xs) + _, _, indices = self.rq(x_e, use_sk=use_sk) + return indices + + def compute_loss(self, out, quant_loss, xs=None): + + if self.loss_type == 'mse': + loss_recon = F.mse_loss(out, xs, reduction='mean') + elif self.loss_type == 'l1': + loss_recon = F.l1_loss(out, xs, reduction='mean') + else: + raise ValueError('incompatible loss type') + + loss_total = loss_recon + self.quant_loss_weight * quant_loss + + return loss_total, loss_recon \ No newline at end of file diff --git a/index/models/vq.py b/index/models/vq.py new file mode 100644 index 0000000000000000000000000000000000000000..bd622380ba4a88897839c92655e99bc47cdf996a --- /dev/null +++ b/index/models/vq.py @@ -0,0 +1,104 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .layers import kmeans, sinkhorn_algorithm + + +class VectorQuantizer(nn.Module): + + def __init__(self, n_e, e_dim, + beta = 0.25, kmeans_init = False, kmeans_iters = 10, + sk_epsilon=0.01, sk_iters=100): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.kmeans_init = kmeans_init + self.kmeans_iters = kmeans_iters + self.sk_epsilon = sk_epsilon + self.sk_iters = sk_iters + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + if not kmeans_init: + self.initted = True + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + else: + self.initted = False + 
self.embedding.weight.data.zero_() + + def get_codebook(self): + return self.embedding.weight + + def get_codebook_entry(self, indices, shape=None): + # get quantized latent vectors + z_q = self.embedding(indices) + if shape is not None: + z_q = z_q.view(shape) + + return z_q + + def init_emb(self, data): + + centers = kmeans( + data, + self.n_e, + self.kmeans_iters, + ) + + self.embedding.weight.data.copy_(centers) + self.initted = True + + @staticmethod + def center_distance_for_constraint(distances): + # distances: B, K + max_distance = distances.max() + min_distance = distances.min() + + middle = (max_distance + min_distance) / 2 + amplitude = max_distance - middle + 1e-5 + assert amplitude > 0 + centered_distances = (distances - middle) / amplitude + return centered_distances + + def forward(self, x, use_sk=True): + # Flatten input + latent = x.view(-1, self.e_dim) + + if not self.initted and self.training: + self.init_emb(latent) + + # Calculate the L2 Norm between latent and Embedded weights + d = torch.sum(latent**2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t()- \ + 2 * torch.matmul(latent, self.embedding.weight.t()) + if not use_sk or self.sk_epsilon <= 0: + indices = torch.argmin(d, dim=-1) + # print("=======",self.sk_epsilon) + else: + # print("++++++++",self.sk_epsilon) + d = self.center_distance_for_constraint(d) + d = d.double() + Q = sinkhorn_algorithm(d,self.sk_epsilon,self.sk_iters) + # print(Q.sum(0)[:10]) + Q = torch.nan_to_num(Q, Q[torch.isfinite(Q)].min().item()) + if torch.isnan(Q).any() or torch.isinf(Q).any(): + print(f"Sinkhorn Algorithm returns nan/inf values.") + indices = torch.argmax(Q, dim=-1) + + # indices = torch.argmin(d, dim=-1) + + x_q = self.embedding(indices).view(x.shape) + + # compute loss for embedding + commitment_loss = F.mse_loss(x_q.detach(), x) + codebook_loss = F.mse_loss(x_q, x.detach()) + loss = codebook_loss + self.beta * commitment_loss + + # preserve gradients + x_q = x + 
(x_q - x).detach() + + indices = indices.view(x.shape[:-1]) + + return x_q, loss, indices + + diff --git a/index/run.sh b/index/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..7462b6ca9622206de7085b67f9221d425332984d --- /dev/null +++ b/index/run.sh @@ -0,0 +1,8 @@ + +python -u main.py \ + --num_emb_list 256 256 256 256 \ + --sk_epsilons 0.0 0.0 0.0 0.003 \ + --device cuda:0 \ + --data_path /data/Games/Games.emb-llama-td.npy \ + --batch_size 1024 + diff --git a/index/trainer.py b/index/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4d5b7f8fbb6637a1f6932101402b72ec923db9 --- /dev/null +++ b/index/trainer.py @@ -0,0 +1,209 @@ +import logging + +import numpy as np +import torch +from time import time +from torch import optim +from tqdm import tqdm + +from utils import ensure_dir,set_color,get_local_time +import os + +class Trainer(object): + + def __init__(self, args, model): + self.args = args + self.model = model + self.logger = logging.getLogger() + + self.lr = args.lr + self.learner = args.learner + self.weight_decay = args.weight_decay + self.epochs = args.epochs + self.eval_step = min(args.eval_step, self.epochs) + self.device = args.device + self.device = torch.device(self.device) + self.ckpt_dir = args.ckpt_dir + saved_model_dir = "{}".format(get_local_time()) + self.ckpt_dir = os.path.join(self.ckpt_dir,saved_model_dir) + ensure_dir(self.ckpt_dir) + + self.best_loss = np.inf + self.best_collision_rate = np.inf + self.best_loss_ckpt = "best_loss_model.pth" + self.best_collision_ckpt = "best_collision_model.pth" + self.optimizer = self._build_optimizer() + self.model = self.model.to(self.device) + + def _build_optimizer(self): + + params = self.model.parameters() + learner = self.learner + learning_rate = self.lr + weight_decay = self.weight_decay + + if learner.lower() == "adam": + optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay) + elif learner.lower() == "sgd": + 
optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
+        elif learner.lower() == "adagrad":
+            optimizer = optim.Adagrad(
+                params, lr=learning_rate, weight_decay=weight_decay
+            )
+            # Adagrad eagerly allocates per-parameter state; move those
+            # tensors onto the training device so steps don't mix CPU/GPU.
+            for state in optimizer.state.values():
+                for k, v in state.items():
+                    if torch.is_tensor(v):
+                        state[k] = v.to(self.device)
+        elif learner.lower() == "rmsprop":
+            optimizer = optim.RMSprop(
+                params, lr=learning_rate, weight_decay=weight_decay
+            )
+        elif learner.lower() == 'adamw':
+            optimizer = optim.AdamW(
+                params, lr=learning_rate, weight_decay=weight_decay
+            )
+        else:
+            self.logger.warning(
+                "Received unrecognized optimizer, set default Adam optimizer"
+            )
+            optimizer = optim.Adam(params, lr=learning_rate)
+        return optimizer
+    def _check_nan(self, loss):
+        # Fail fast rather than continue training on a diverged (NaN) loss.
+        if torch.isnan(loss):
+            raise ValueError("Training loss is nan")
+
+    def _train_epoch(self, train_data, epoch_idx):
+        # One optimization pass over train_data. Returns the SUMMED total loss
+        # and summed reconstruction loss over all batches (not averaged).
+
+        self.model.train()
+
+        total_loss = 0
+        total_recon_loss = 0
+        iter_data = tqdm(
+            train_data,
+            total=len(train_data),
+            ncols=100,
+            desc=set_color(f"Train {epoch_idx}","pink"),
+        )
+
+        for batch_idx, data in enumerate(iter_data):
+            data = data.to(self.device)
+            self.optimizer.zero_grad()
+            out, rq_loss, indices = self.model(data)
+            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
+            self._check_nan(loss)
+            loss.backward()
+            self.optimizer.step()
+            total_loss += loss.item()
+            total_recon_loss += loss_recon.item()
+
+        return total_loss, total_recon_loss
+
+    @torch.no_grad()
+    def _valid_epoch(self, valid_data):
+        # Measure the index collision rate: fraction of samples whose full
+        # quantized code (all levels joined) is shared with another sample.
+
+        self.model.eval()
+
+        iter_data =tqdm(
+            valid_data,
+            total=len(valid_data),
+            ncols=100,
+            desc=set_color(f"Evaluate ", "pink"),
+        )
+        indices_set = set()
+        num_sample = 0
+        for batch_idx, data in enumerate(iter_data):
+            num_sample += len(data)
+            data = data.to(self.device)
+            indices = self.model.get_indices(data)
+            indices = indices.view(-1,indices.shape[-1]).cpu().numpy()
+            for index in indices:
+                # One string code per sample, e.g. "12-3-255-7".
+                code = "-".join([str(int(_)) for _ in index])
+
indices_set.add(code)
+
+        collision_rate = (num_sample - len(indices_set))/num_sample
+
+        return collision_rate
+
+    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
+        # Persist model + optimizer + bookkeeping. When ckpt_file is given the
+        # fixed name is reused (rolling "best" checkpoint); otherwise the file
+        # name encodes the epoch and collision rate.
+
+        ckpt_path = os.path.join(self.ckpt_dir,ckpt_file) if ckpt_file \
+            else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate))
+        state = {
+            "args": self.args,
+            "epoch": epoch,
+            "best_loss": self.best_loss,
+            "best_collision_rate": self.best_collision_rate,
+            "state_dict": self.model.state_dict(),
+            "optimizer": self.optimizer.state_dict(),
+        }
+        # pickle protocol 4 supports objects larger than 4 GB.
+        torch.save(state, ckpt_path, pickle_protocol=4)
+
+        self.logger.info(
+            set_color("Saving current", "blue") + f": {ckpt_path}"
+        )
+
+    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
+        # Format one colored log line: epoch, wall time, total and recon loss.
+        train_loss_output = (
+            set_color("epoch %d training", "green")
+            + " ["
+            + set_color("time", "blue")
+            + ": %.2fs, "
+        ) % (epoch_idx, e_time - s_time)
+        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
+        train_loss_output +=", "
+        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
+        return train_loss_output + "]"
+
+
+    def fit(self, data):
+        # Full training driver. NOTE(review): evaluation reuses the training
+        # loader `data`, so the collision rate is measured on the training set
+        # — presumably intentional for index learning; confirm with caller.
+
+        cur_eval_step = 0
+
+        for epoch_idx in range(self.epochs):
+            # train
+            training_start_time = time()
+            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
+            training_end_time = time()
+            train_loss_output = self._generate_train_loss_output(
+                epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
+            )
+            self.logger.info(train_loss_output)
+
+            if train_loss < self.best_loss:
+                self.best_loss = train_loss
+                # self._save_checkpoint(epoch=epoch_idx,ckpt_file=self.best_loss_ckpt)
+
+            # eval
+            if (epoch_idx + 1) % self.eval_step == 0:
+                valid_start_time = time()
+                collision_rate = self._valid_epoch(data)
+
+                if collision_rate < self.best_collision_rate:
+                    self.best_collision_rate = collision_rate
+                    cur_eval_step = 0
+                    self._save_checkpoint(epoch_idx, 
collision_rate=collision_rate,
+                                          ckpt_file=self.best_collision_ckpt)
+                else:
+                    # NOTE(review): counts evaluations without improvement but
+                    # is never read afterwards — early stopping is not wired up.
+                    cur_eval_step += 1
+
+
+                valid_end_time = time()
+                valid_score_output = (
+                    set_color("epoch %d evaluating", "green")
+                    + " ["
+                    + set_color("time", "blue")
+                    + ": %.2fs, "
+                    + set_color("collision_rate", "blue")
+                    + ": %f]"
+                ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)
+
+                self.logger.info(valid_score_output)
+                # NOTE(review): per-epoch checkpoints only start after epoch
+                # 1000, so this branch is dead for short runs — confirm intent.
+                if epoch_idx>1000:
+                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate)
+
+
+        return self.best_loss, self.best_collision_rate
+
+
+
+
diff --git a/index/utils.py b/index/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..15464c7a6d1a975e26fdc6b0430d24edf0722596
--- /dev/null
+++ b/index/utils.py
@@ -0,0 +1,36 @@
+
+import datetime
+import os
+
+
+def ensure_dir(dir_path):
+    # Create dir_path (and parents) if missing; no-op when it exists.
+
+    os.makedirs(dir_path, exist_ok=True)
+
+def set_color(log, color, highlight=True):
+    # Wrap `log` in ANSI escape codes for the named color; unknown colors
+    # fall back to the last entry ("white").
+    color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
+    try:
+        index = color_set.index(color)
+    # NOTE(review): bare except — narrower `except ValueError` would avoid
+    # masking unrelated errors.
+    except:
+        index = len(color_set) - 1
+    prev_log = "\033["
+    if highlight:
+        prev_log += "1;3"
+    else:
+        prev_log += "0;3"
+    prev_log += str(index) + "m"
+    return prev_log + log + "\033[0m"
+
+def get_local_time():
+    r"""Get current time
+
+    Returns:
+        str: current time
+    """
+    cur = datetime.datetime.now()
+    cur = cur.strftime("%b-%d-%Y_%H-%M-%S")
+
+    return cur
+
+
+
diff --git a/infer.sh b/infer.sh
new file mode 100644
index 0000000000000000000000000000000000000000..f07dd270086daec75ceed40d22b00bf5de3f7528
--- /dev/null
+++ b/infer.sh
@@ -0,0 +1,14 @@
+CKPT_PATH=${datain}/v-yinju/rq-llama/v6/Instruments
+
+# Dump item/user embeddings from the fine-tuned checkpoint, then derive indices.
+python generate_embeddings.py \
+    --ckpt_path $CKPT_PATH \
+    --item_save_path $CKPT_PATH/embeddings.item.tsv \
+    --user_save_path $CKPT_PATH/embeddings.user.tsv \
+    --device_map 0
+
+python generate_indices.py \
+    --ckpt_path $CKPT_PATH \
+    --item_data_path $CKPT_PATH/embeddings.item.tsv \
+    --user_data_path $CKPT_PATH/embeddings.user.tsv 
\ + --save_path $CKPT_PATH \ + --device_map 0 \ No newline at end of file diff --git a/instruments_evaluate.sh b/instruments_evaluate.sh new file mode 100644 index 0000000000000000000000000000000000000000..1b37aee42fffae6fe0dcc1660ec988e033c9500d --- /dev/null +++ b/instruments_evaluate.sh @@ -0,0 +1,18 @@ +DATASET=Instruments +BASE_MODEL=/datain/v-yinju/llama-7b +DATA_PATH=/datain/v-yinju/rqvae-zzx/data +CKPT_PATH=/datain/v-yinju/rq-llama/v11.2/Ins/finetune +RESULTS_FILE=$CKPT_PATH/eval_result.json + +torchrun --nproc_per_node=8 evaluate-finetuned.py \ + --base_model $BASE_MODEL \ + --ckpt_path $CKPT_PATH \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --results_file $RESULTS_FILE \ + --test_batch_size 1 \ + --num_beams 20 \ + --test_prompt_ids all \ + --test_task seqrec \ + --index_file /datain/v-yinju/rq-llama/v11.2/Ins/indices.item.json \ + --user_index_file /datain/v-yinju/rq-llama/v11.2/Ins/indices.user.json \ No newline at end of file diff --git a/instruments_finetune.sh b/instruments_finetune.sh new file mode 100644 index 0000000000000000000000000000000000000000..8bd9cc84e15ba937198750f26d1f0585c0d43999 --- /dev/null +++ b/instruments_finetune.sh @@ -0,0 +1,33 @@ +export WANDB_MODE=disabled +export CUDA_LAUNCH_BLOCKING=0 + +DATASET=Instruments +CKPT_PATH=/datain/v-yinju/rq-llama/v11/Instruments +DATA_PATH=/datain/v-yinju/rqvae-zzx/data +OUTPUT_DIR=$CKPT_PATH/finetune + +torchrun --nproc_per_node=8 fine-tune.py \ + --ckpt_path $CKPT_PATH \ + --output_dir $OUTPUT_DIR \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --per_device_batch_size 6 \ + --gradient_accumulation_steps 2 \ + --learning_rate 5e-5 \ + --epochs 4 \ + --weight_decay 0.01 \ + --save_and_eval_strategy epoch \ + --fp16 \ + --deepspeed ./config/ds_z2_fp16.json \ + --dataloader_num_workers 4 \ + --only_train_response \ + --tasks seqrec,itemsearch,preferenceobtain,item2index,index2item,fusionseqrec,usersearch,user2pref,pref2user \ + --train_prompt_sample_num 1,1,1,1,1,1,1,1,1 \ + 
--train_data_sample_num 0,0,0,0,0,0,0,0,0 \ + --index_file $CKPT_PATH/indices.item.json \ + --user_index_file $CKPT_PATH/indices.user.json \ + --reindex 17 + +cd convert +nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 & +cd .. \ No newline at end of file diff --git a/instruments_more_pretrain.sh b/instruments_more_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..f2387587acf9ac17093eea5bd6baeb722cc3515b --- /dev/null +++ b/instruments_more_pretrain.sh @@ -0,0 +1,32 @@ +export WANDB_MODE=disabled +export CUDA_LAUNCH_BLOCKING=0 + +DATASET=Instruments +BASE_MODEL=/datain/v-yinju/llama-7b +CKPT_PATH=/datain/v-yinju/rq-llama/v6/Instruments +DATA_PATH=/datain/v-yinju/rqvae-zzx/data +OUTPUT_DIR=/datain/v-yinju/rq-llama/v3-train/Instruments/more_pretrain + +torchrun --nproc_per_node=8 --master_port=3324 continue_pretrain.py \ + --base_model $BASE_MODEL \ + --ckpt_path $CKPT_PATH \ + --output_dir $OUTPUT_DIR \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --per_device_batch_size 6 \ + --gradient_accumulation_steps 2 \ + --learning_rate 5e-5 \ + --epochs 4 \ + --weight_decay 0.01 \ + --save_and_eval_strategy epoch \ + --deepspeed ./config/ds_z2_fp16.json \ + --dataloader_num_workers 4 \ + --only_train_response \ + --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item \ + --train_prompt_sample_num 1,1,1,1,1,1,1,1,1 \ + --train_data_sample_num 0,0,0,0,0,0,0,0,0 \ + --fp16 &>>$OUTPUT_DIR/pretrain-log.txt + +cd convert +nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 & +cd .. 
\ No newline at end of file diff --git a/instruments_pretrain.sh b/instruments_pretrain.sh new file mode 100644 index 0000000000000000000000000000000000000000..87aecf0cd5d8f0c5e4e53992d231342913b805c1 --- /dev/null +++ b/instruments_pretrain.sh @@ -0,0 +1,85 @@ +export WANDB_MODE=disabled +export CUDA_LAUNCH_BLOCKING=0 + +DATASET=Instruments +BASE_MODEL=$datain/v-yinju/llama-7b +ITEM_MODEL=$datain/v-yinju/rqvae-zzx/models/instruments/Apr-01-2024_01-25-11/best_collision_model.pth +USER_MODEL=$datain/v-yinju/rqvae-zzx/models/instruments/user/Apr-23-2024_03-36-04/best_collision_model.pth +DATA_PATH=$datain/v-yinju/rqvae-zzx/data +OUTPUT_DIR=$datain/v-yinju/rq-llama/v11.2/Ins + +torchrun --nproc_per_node=8 pre-train.py \ + --base_model $BASE_MODEL \ + --item_model $ITEM_MODEL \ + --user_model $USER_MODEL \ + --output_dir $OUTPUT_DIR \ + --dataset $DATASET \ + --data_path $DATA_PATH \ + --per_device_batch_size 6 \ + --gradient_accumulation_steps 2 \ + --learning_rate 5e-4 \ + --epochs 4 \ + --weight_decay 0.01 \ + --save_and_eval_strategy epoch \ + --deepspeed ./config/ds_z2_fp16.json \ + --dataloader_num_workers 4 \ + --only_train_response \ + --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item,usersearch,user2pref,pref2user \ + --train_prompt_sample_num 1,1,1,1,1,1,1,1,1,1,1,1 \ + --train_data_sample_num 0,0,0,0,0,0,0,0,0,0,0,0 \ + --index_file .index.json \ + --user_index_file .user-index.json \ + --fp16 + +cd convert +nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 & +cd .. 
+ +CKPT_PATH=$datain/v-yinju/rq-llama/v11.2/Ins + +python generate_embeddings.py \ + --ckpt_path $CKPT_PATH \ + --item_save_path $CKPT_PATH/embeddings.item.tsv \ + --user_save_path $CKPT_PATH/embeddings.user.tsv \ + --device_map 0 + +python generate_indices.py \ + --ckpt_path $CKPT_PATH \ + --item_data_path $CKPT_PATH/embeddings.item.tsv \ + --user_data_path $CKPT_PATH/embeddings.user.tsv \ + --save_path $CKPT_PATH \ + --device_map 0 + +# DATASET=Games +# BASE_MODEL=/datain/v-yinju/llama-7b +# ITEM_MODEL=/datain/v-yinju/rqvae-zzx/models/games/Apr-18-2024_01-51-46/best_collision_model.pth +# USER_MODEL=/datain/v-yinju/rqvae-zzx/models/games/user/Jun-17-2024_18-40-36/best_collision_model.pth +# DATA_PATH=/datain/v-yinju/rqvae-zzx/data +# OUTPUT_DIR=/datain/v-yinju/rq-llama/v11/Games + +# torchrun --nproc_per_node=8 pre-train.py \ +# --base_model $BASE_MODEL \ +# --item_model $ITEM_MODEL \ +# --user_model $USER_MODEL \ +# --output_dir $OUTPUT_DIR \ +# --dataset $DATASET \ +# --data_path $DATA_PATH \ +# --per_device_batch_size 6 \ +# --gradient_accumulation_steps 2 \ +# --learning_rate 5e-5 \ +# --epochs 4 \ +# --weight_decay 0.01 \ +# --save_and_eval_strategy epoch \ +# --deepspeed ./config/ds_z2_fp16.json \ +# --dataloader_num_workers 4 \ +# --only_train_response \ +# --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item,usersearch,user2pref,pref2user \ +# --train_prompt_sample_num 1,1,1,1,1,1,1,1,1,1,1,1 \ +# --train_data_sample_num 0,0,0,0,0,0,0,0,0,0,0,0 \ +# --index_file .index.json \ +# --user_index_file .user-index.json \ +# --fp16 + +# cd convert +# nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 & +# cd .. 
\ No newline at end of file diff --git a/lora_finetune.py b/lora_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..612faaeeb1ec9780c446b819a9b94064a4a60214 --- /dev/null +++ b/lora_finetune.py @@ -0,0 +1,162 @@ +import argparse +import os +import sys +from typing import List + +import torch +import transformers + + +from peft import ( + TaskType, + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from utils import * +from collator import Collator + +def train(args): + + set_seed(args.seed) + ensure_dir(args.output_dir) + + device_map = "auto" + world_size = int(os.environ.get("WORLD_SIZE", 1)) + ddp = world_size != 1 + local_rank = int(os.environ.get("LOCAL_RANK") or 0) + if local_rank == 0: + print(vars(args)) + + if ddp: + device_map = {"": local_rank} + + config = LlamaConfig.from_pretrained(args.base_model) + tokenizer = LlamaTokenizer.from_pretrained( + args.base_model, + model_max_length=args.model_max_length, + padding_side="right", + ) + tokenizer.pad_token_id = 0 + + train_data, valid_data = load_datasets(args) + add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens()) + config.vocab_size = len(tokenizer) + if local_rank == 0: + print("add {} new token.".format(add_num)) + print("data num:", len(train_data)) + tokenizer.save_pretrained(args.output_dir) + config.save_pretrained(args.output_dir) + + collator = Collator(args, tokenizer) + + model = LlamaForCausalLM.from_pretrained( + args.base_model, + # torch_dtype=torch.float16, + device_map=device_map, + ) + model.resize_token_embeddings(len(tokenizer)) + + config = LoraConfig( + r=args.lora_r, + lora_alpha=args.lora_alpha, + target_modules=args.lora_target_modules.split(","), + modules_to_save=args.lora_modules_to_save.split(","), + lora_dropout=args.lora_dropout, + bias="none", + inference_mode=False, + task_type=TaskType.CAUSAL_LM, + ) + model 
= get_peft_model(model, config) + + if args.resume_from_checkpoint: + checkpoint_name = os.path.join( + args.resume_from_checkpoint, "adapter_model.bin" + ) # only LoRA model - LoRA config above has to fit + args.resume_from_checkpoint = False # So the trainer won't try loading its state + # The two files above have a different name depending on how they were saved, but are actually the same. + if os.path.exists(checkpoint_name): + if local_rank == 0: + print(f"Restarting from {checkpoint_name}") + adapters_weights = torch.load(checkpoint_name) + model = set_peft_model_state_dict(model, adapters_weights) + else: + if local_rank == 0: + print(f"Checkpoint {checkpoint_name} not found") + + for n, p in model.named_parameters(): + if "original_module" in n and any(module_name in n for module_name in config.modules_to_save): + p.requires_grad = False + + if local_rank == 0: + model.print_trainable_parameters() + + + if not ddp and torch.cuda.device_count() > 1: + model.is_parallelizable = True + model.model_parallel = True + + trainer = transformers.Trainer( + model=model, + train_dataset=train_data, + eval_dataset=valid_data, + args=transformers.TrainingArguments( + seed=args.seed, + per_device_train_batch_size=args.per_device_batch_size, + per_device_eval_batch_size=args.per_device_batch_size, + gradient_accumulation_steps=args.gradient_accumulation_steps, + warmup_ratio=args.warmup_ratio, + num_train_epochs=args.epochs, + learning_rate=args.learning_rate, + weight_decay=args.weight_decay, + lr_scheduler_type=args.lr_scheduler_type, + fp16=args.fp16, + bf16=args.bf16, + logging_steps=args.logging_step, + optim=args.optim, + gradient_checkpointing=True, + evaluation_strategy=args.save_and_eval_strategy, + save_strategy=args.save_and_eval_strategy, + eval_steps=args.save_and_eval_steps, + save_steps=args.save_and_eval_steps, + output_dir=args.output_dir, + save_total_limit=5, + load_best_model_at_end=True, + deepspeed=args.deepspeed, + ddp_find_unused_parameters=False 
if ddp else None, + report_to=None, + eval_delay=1 if args.save_and_eval_strategy=="epoch" else 2000, + ), + tokenizer=tokenizer, + data_collator=collator, + ) + model.config.use_cache = False + + # old_state_dict = model.state_dict + # model.state_dict = ( + # lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) + # ).__get__(model, type(model)) + + if torch.__version__ >= "2" and sys.platform != "win32": + model = torch.compile(model) + + trainer.train( + resume_from_checkpoint=args.resume_from_checkpoint, + ) + + trainer.save_state() + trainer.save_model(output_dir=args.output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='LLMRec') + parser = parse_global_args(parser) + parser = parse_train_args(parser) + parser = parse_dataset_args(parser) + + args = parser.parse_args() + + train(args) diff --git a/pre-train.py b/pre-train.py new file mode 100644 index 0000000000000000000000000000000000000000..22c9475fcca1303f4c45b0fed78aa9f9ec95ac2d --- /dev/null +++ b/pre-train.py @@ -0,0 +1,154 @@ +import os +import sys +from typing import List +import argparse + +import wandb +import torch +import transformers +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from peft import ( + TaskType, + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) + +from collator import VanillaCollator +from rq_llama import * +from utils import * + +parser = argparse.ArgumentParser(description = 'rqllama-pretrain') +parser = parse_global_args(parser) +parser = parse_train_args(parser) +parser = parse_dataset_args(parser) +parser = parse_rqvae_args(parser) +args = parser.parse_args() +wandb.init(config = args, reinit = True) + +set_seed(args.seed) +ensure_dir(args.output_dir) + +device_map = "auto" +world_size = int(os.environ.get("WORLD_SIZE", 1)) +ddp = world_size != 1 +local_rank = int(os.environ.get("LOCAL_RANK") or 0) +if local_rank == 0: + print(vars(args)) +if ddp: + 
device_map = {"": local_rank} + +train_data, valid_data = load_datasets(args) + +config = LlamaConfig.from_pretrained(args.base_model) +config.args = vars(args) +rqllama = LlamaWithRQ(config) + +ckpt = torch.load(args.item_model, map_location = torch.device('cpu')) +state_dict = ckpt["state_dict"] +rqllama.item_rqvae.load_state_dict(state_dict) +for i in range(len(args.num_emb_list)): + rqllama.item_rqvae.rq.vq_layers[i].initted = True +ckpt = torch.load(args.user_model, map_location = torch.device('cpu')) +state_dict = ckpt["state_dict"] +rqllama.user_rqvae.load_state_dict(state_dict) +for i in range(len(args.num_emb_list)): + rqllama.user_rqvae.rq.vq_layers[i].initted = True + +if local_rank == 0: + print("token num:", len(rqllama.tokenizer)) + print("data num:", len(train_data)) + rqllama.tokenizer.save_pretrained(args.output_dir) + rqllama.config.save_pretrained(args.output_dir) + +if args.resume_from_checkpoint: + checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin") + args.resume_from_checkpoint = False + if os.path.exists(checkpoint_name): + if local_rank == 0: + print(f"Restarting from {checkpoint_name}") + adapters_weights = torch.load(checkpoint_name) + rqllama.model = set_peft_model_state_dict(rqllama.model, adapters_weights) + else: + if local_rank == 0: + print(f"Checkpoint {checkpoint_name} not found") + +if local_rank == 0: + rqllama.model.print_trainable_parameters() + +if not ddp and torch.cuda.device_count() > 1: + rqllama.is_parallelizable = True + rqllama.model_parallel = True + +collator = VanillaCollator(args, rqllama.tokenizer) + +trainer = transformers.Trainer( + model = rqllama, + train_dataset = train_data, + eval_dataset = valid_data, + args = transformers.TrainingArguments( + seed = args.seed, + per_device_train_batch_size = args.per_device_batch_size, + per_device_eval_batch_size = args.per_device_batch_size, + gradient_accumulation_steps = args.gradient_accumulation_steps, + warmup_ratio = args.warmup_ratio, 
+        num_train_epochs = args.epochs,
+        learning_rate = args.learning_rate,
+        weight_decay = args.weight_decay,
+        lr_scheduler_type = args.lr_scheduler_type,
+        fp16 = args.fp16,
+        bf16 = args.bf16,
+        logging_steps = args.logging_step,
+        optim = args.optim,
+        gradient_checkpointing = True,
+        evaluation_strategy = args.save_and_eval_strategy,
+        save_strategy = args.save_and_eval_strategy,
+        eval_steps = args.save_and_eval_steps,
+        save_steps = args.save_and_eval_steps,
+        output_dir = args.output_dir,
+        save_total_limit = 5,
+        load_best_model_at_end = True,
+        deepspeed = args.deepspeed,
+        ddp_find_unused_parameters = False if ddp else None,
+        report_to = None,
+        # Delay first eval: 1 epoch in epoch mode, 2000 steps otherwise.
+        eval_delay = 1 if args.save_and_eval_strategy=="epoch" else 2000,
+        dataloader_num_workers = args.dataloader_num_workers,
+        dataloader_prefetch_factor = args.dataloader_prefetch_factor,
+        remove_unused_columns = args.remove_unused_columns,
+    ),
+    tokenizer = rqllama.tokenizer,
+    data_collator = collator,
+)
+rqllama.config.use_cache = False
+
+if torch.__version__ >= "2" and sys.platform != "win32":
+    rqllama = torch.compile(rqllama)
+
+trainer.train(resume_from_checkpoint = args.resume_from_checkpoint)
+
+trainer.save_state()
+trainer.save_model(output_dir = args.output_dir)
+
+# Completion notification from the rank-0 process only.
+# NOTE(security): the SMTP account, app password, and personal addresses below
+# are hard-coded in source. Move them to environment variables/secret storage
+# and rotate the leaked credential before publishing this file.
+if local_rank == 0:
+    # print('rqllama pre-train finished.')
+
+    import smtplib
+    from email.mime.text import MIMEText
+    mail_host = 'smtp.qq.com'
+    mail_code = 'ouzplpngooqndjcb'
+    sender = '1849334588@qq.com'
+    receiver = 'esperanto1949@foxmail.com'
+
+    task = '[v53: pretrain tt.ins.5e-4 w/o projector]'
+    message = MIMEText('Task {task} Finished'.format(task = task), 'plain', 'utf-8')
+    message['Subject'] = 'Auto Email'
+    message['From'] = sender
+    message['To'] = receiver
+
+    server = smtplib.SMTP_SSL("smtp.qq.com", 465)
+    server.login(sender, mail_code)
+    server.sendmail(sender, receiver, message.as_string())
+
+    server.quit()
\ No newline at end of file
diff --git a/prompt.py b/prompt.py
new file mode 100644
index 
0000000000000000000000000000000000000000..d5bda06619048d64a0ddd0673cc33a27c0b45965 --- /dev/null +++ b/prompt.py @@ -0,0 +1,811 @@ +sft_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." \ + "\n\n### Instruction:\n{instruction}\n\n### Response:{response}" + +all_prompt = {} + +# ===================================================== +# Task 12 -- User2Preference -- 8 Prompt +# ===================================================== + +user2pref_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "What is the preference of user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the preference of user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Briefly summarize the preference of user {user}." +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you share with me the preference description corresponding to user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "How to describe the preference of user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "I need to know the preference of user {user}. Could you help me with that?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Please provide a description of user {user}'s preference." +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Can you provide the corresponding description for user {user}'s preference?" 
+prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +all_prompt["user2pref"] = user2pref_prompt + +# ===================================================== +# Task 11 -- Preference2User -- 8 Prompt +# ===================================================== + +pref2user_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "A user has the following preference: \"{preference}\". Which user is it describing?" +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Can you tell which user has such a preference: \"{preference}\"?" +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Your task is to determine the corresponding user based on his preference. Here is the preference description: \"{preference}\"." +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Please identify the user from the provided preference: \"{preference}\"." +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "Which user has the following preference: \"{preference}\"?" +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Given the textual description of someone's preference as \"{preference}\", identify the corresponding user." +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Based on the provided preference \"{preference}\", answer which user is it referring to?" +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Which user can be characterized by the following description: \"{preference}\"?" 
+prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +all_prompt["pref2user"] = pref2user_prompt + +# ===================================================== +# Task 10 -- Possible User Prediction -- 10 Prompt +# ===================================================== + +usersearch_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The item {item} has been historically clicked by users {users}. Can you predict another possible user which will click this item?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Given the item {item} and its historical interactive users {users}, I want to know which user will click this item. Please provide a reccomendation." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Here are the item {item}'s historical interactive users: {users}, try to predict another user that is fond of this item." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "According to the users {users} that have clicked the item {item}, can you determine the next possible user wanting the same item?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "By analyzing the item {item}'s historical interactions with users {users}, who is the next expected interactive user?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "After clicked by these users {users}, who is the next user that may be keen on the item {item}?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Who is the most potential user for the given item {item} that has previously interacted with the following users {users}?" 
+prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Taking the item {item}'s historical interactions as condition, predict the next user that may highly enjoy the same item. Here is the historical interactive users {users}." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "You have access to the item {item}'s historical user interaction record {users}. Now your task is to predict another possible user that loves the same item based on the past interaction." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "Considering the fact that several users {users} have clicked the same item {item}, forecast who is the next user that will be insterested in this item." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +all_prompt["usersearch"] = usersearch_prompt + +# ===================================================== +# Task 1 -- Sequential Recommendation -- 17 Prompt +# ===================================================== + +seqrec_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with items {inters} in chronological order. Can you predict the next possible item that the user may expect?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "I find the historical interactive items of user {user}: {inters}, and I want to know what next item the user needs. Can you help me decide?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Here are the user {user}'s historical interactions: {inters}, try to recommend another item to the user. Note that the historical interactions are arranged in chronological order." 
+prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Based on the items that the user {user} has interacted with: {inters}, can you determine what item would be recommended to him next?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with the following items in order: {inters}. What else do you think the user need?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Here is the item interaction history of the user {user}: {inters}, what to recommend to the user next?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Which item would the user {user} be likely to interact with next after interacting with items {inters}?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "By analyzing the user {user}'s historical interactions with items {inters}, what is the next expected interaction item?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "After interacting with items {inters}, what is the next item that could be recommended for the user {user}?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "Given the user {user}'s historical interactive items arranged in chronological order: {inters}, can you recommend a suitable item for the user?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Considering the user {user} has interacted with items {inters}. What is the next recommendation for the user?" 
+prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "What is the top recommended item for the user {user} who has previously interacted with items {inters} in order?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——12 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with the following items in the past in order: {inters}. Please predict the next item that the user most desires based on the given interaction records." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Using the user {user}'s historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "You can access the user {user}'s historical item interaction records: {inters}. Now your task is to recommend the next potential item to him, considering his past interactions." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "You have observed that the user {user} has interacted with the following items: {inters}, please recommend a next item that you think would be suitable for the user." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user {user}." 
+prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +all_prompt["seqrec"] = seqrec_prompt + + + +# ======================================================== +# Task 2 -- Item2Index -- 19 Prompt +# ======================================================== +# Remove periods when inputting + +item2index_prompt = [] + +# ======================================================== +# Title2Index + +#####——0 +prompt = {} +prompt["instruction"] = "Which item has the title: \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Which item is assigned the title: \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "An item is called \"{title}\", could you please let me know which item it is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Which item is called \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "One of the items is named \"{title}\", can you tell me which item this is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "What is the item that goes by the title \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +# ======================================================== +# Description2Index + +#####——6 +prompt = {} +prompt["instruction"] = "An item can be described as follows: \"{description}\". Which item is it describing?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Can you tell me what item is described as \"{description}\"?" 
+prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you provide the item that corresponds to the following description: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "Which item has the following characteristics: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Which item is characterized by the following description: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "I am curious to know which item can be described as follows: \"{description}\". Can you tell me?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +# ======================================================== +# Title and Description to index + +#####——12 +prompt = {} +prompt["instruction"] = "An item is called \"{title}\" and described as \"{description}\", can you tell me which item it is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Could you please identify what item is called \"{title}\" and described as \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "Which item is called \"{title}\" and has the characteristics described below: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "Please show me which item is named \"{title}\" and its corresponding description is: \"{description}\"." +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "Determine which item this is by its title and description. 
The title is: \"{title}\", and the description is: \"{description}\"." +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——17 +prompt = {} +prompt["instruction"] = "Based on the title: \"{title}\", and the description: \"{description}\", answer which item is this?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——18 +prompt = {} +prompt["instruction"] = "Can you identify the item from the provided title: \"{title}\", and description: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +all_prompt["item2index"] = item2index_prompt + + +# ======================================================== +# Task 3 -- Index2Item --17 Prompt +# ======================================================== +# Remove periods when inputting + +index2item_prompt = [] + +# ======================================================== +# Index2Title + +#####——0 +prompt = {} +prompt["instruction"] = "What is the title of item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "What title is assigned to item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Could you please tell me what item {item} is called?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you provide the title of item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "What item {item} is referred to as?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the title of item {item}?" 
+prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +# ======================================================== +# Index2Description + +#####——6 +prompt = {} +prompt["instruction"] = "Please provide a description of item {item}." +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Briefly describe item {item}." +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you share with me the description corresponding to item {item}?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "What is the description of item {item}?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "How to describe the characteristics of item {item}?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "Could you please tell me what item {item} looks like?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + + +# ======================================================== +# index to Title and Description + +#####——12 +prompt = {} +prompt["instruction"] = "What is the title and description of item {item}?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Can you provide the corresponding title and description for item {item}?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "Please tell me what item {item} is called, along with a brief description of it." 
+prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the title of the item {item} and how to describe its characteristics?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "I need to know the title and description of item {item}. Could you help me with that?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +all_prompt["index2item"] = index2item_prompt + + + + + +# ======================================================== +# Task 4 -- Interactions2Title -- Prompt +# ======================================================== + + +inters2title_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The user {user} has sequentially interacted with items {inters}. Can you recommend the next item for him? Tell me the title of the item?" +prompt["response"] = "{title}" +inters2title_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Based on the user {user}'s historical interactions: {inters}, try to predict the title of the item that the user may need next." +prompt["response"] = "{title}" +inters2title_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Utilizing the user {user}'s past ordered interactions, which include items {inters}, please recommend the next item you think is suitable for the user and provide its title." +prompt["response"] = "{title}" +inters2title_prompt.append(prompt) + + +#####——3 +prompt = {} +prompt["instruction"] = "After interacting with items {inters}, what is the most probable item for the user {user} to interact with next? Kindly provide the item's title." 
+prompt["response"] = "{title}" +inters2title_prompt.append(prompt) + +all_prompt["inters2title"] = inters2title_prompt + + +# ======================================================== +# Task 5 -- Interactions2Description -- Prompt +# ======================================================== + +inters2description_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "Please review the user {user}'s historical interactions: {inters}, and describe what kind of item he still needs." +prompt["response"] = "{description}" +inters2description_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Here is the item interaction history of the user {user}: {inters}, please tell me what features he expects from his next item." +prompt["response"] = "{description}" +inters2description_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "By analyzing the user {user}'s historical interactions with items {inters}, can you infer what the user's next interactive item will look like?" +prompt["response"] = "{description}" +inters2description_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Access the user {user}'s historical item interaction records: {inters}. Your objective is to describe the next potential item for him, taking into account his past interactions." +prompt["response"] = "{description}" +inters2description_prompt.append(prompt) + +all_prompt["inters2description"] = inters2description_prompt + + +# ======================================================== +# Task 6 -- InteractedTitles2Item -- Prompt +# ======================================================== + +intertitles2item_prompt = [] +#####——0 +prompt = {} +prompt["instruction"] = "Given the title sequence of user {user} historical interactive items: {inter_titles}, can you recommend a suitable next item for the user?" 
+prompt["response"] = "{item}" +intertitles2item_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "I possess a user {user}'s past interaction history, denoted by the title sequence of interactive items: {inter_titles}, and I am interested in knowing the user's next most desired item. Can you help me?" +prompt["response"] = "{item}" +intertitles2item_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Considering the title sequence of user {user} history interaction items: {inter_titles}. What is the next recommendation for the user?" +prompt["response"] = "{item}" +intertitles2item_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "You have obtained the ordered title list of user {user} historical interaction items, as follows: {inter_titles}. Based on this historical context, kindly choose the subsequent item for his recommendation." +prompt["response"] = "{item}" +intertitles2item_prompt.append(prompt) + +all_prompt["intertitles2item"] = intertitles2item_prompt + + +# ======================================================== +# Task 7 -- ItemSearch -- Prompt +# ======================================================== + +itemsearch_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "Here is the historical interactions of a user: {inters}. And his personalized preferences are as follows: \"{explicit_preference}\". Your task is to recommend an item that is consistent with the user's preference." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "The user has interacted with a list of items, which are as follows: {inters}. Based on these interacted items, the user current intent is as follows \"{user_related_intention}\", and your task is to generate an item that matches the user's current intent." 
+prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "As a recommender system, you are assisting a user who has recently interacted with the following items: {inters}. The user expresses a desire to obtain another item with the following characteristics: \"{item_related_intention}\". Please recommend an item that meets these criteria." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Using the user current query: \"{query}\" and his historical interactions: {inters}, you can estimate the user's preferences \"{explicit_preference}\". Please respond to the user's query by selecting an item that best matches his preference and query." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "The user needs a new item and searches for: \"{query}\". In addition, he has previously interacted with: {inters}. You can obtain his preference by analyzing his historical interactions: \"{explicit_preference}\". Can you recommend an item that best matches the search query and preferences?" +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Based on the user historical interactions with the following items: {inters}. You can infer his preference by observing the historical interactions: \"{explicit_preference}\". Now the user wants a new item and searches for: \"{query}\". Please select a suitable item that matches his preference and search intent." 
+prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +all_prompt["itemsearch"] = itemsearch_prompt + +# ======================================================== +# Task 8 -- Query2Item -- Prompt +# ======================================================== + +query2item_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "Suppose you are a search engine, now a user {user} searches that: \"{query}\", can you select an item to respond to the user's query?" +prompt["response"] = "{item}" +query2item_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "As a search engine, your task is to answer the user's query by generating a related item. The user {user}'s query is provided as \"{query}\". Please provide your generated item as your answer." +prompt["response"] = "{item}" +query2item_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "As a recommender system, your task is to recommend an item that is related to the user {user}'s request, which is specified as follows: \"{query}\". Please provide your recommendation." +prompt["response"] = "{item}" +query2item_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "You meet user {user}'s query: \"{query}\". Please respond to this user by selecting an appropriate item." +prompt["response"] = "{item}" +query2item_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "Your task is to recommend the best item that matches the user's query. Here is the search query of the user {user}: \"{query}\", tell me the item you recommend." 
+prompt["response"] = "{item}" +query2item_prompt.append(prompt) + +all_prompt["query2item"] = query2item_prompt + + +# ======================================================== +# Task 9 -- PreferenceObtain -- Prompt +# ======================================================== + +preferenceobtain_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The user has interacted with items {inters} in chronological order. Please estimate his preferences." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Based on the items that the user has interacted with: {inters}, can you infer what preferences he has?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you provide a summary of the user preferences based on his historical interactions: {inters}?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "After interacting with items {inters} in order, what preferences do you think the user has?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Here is the item interaction history of the user: {inters}, could you please infer the user preferences." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Based on the user historical interaction records: {inters}, what are your speculations about his preferences?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Given the user historical interactive items arranged in chronological order: {inters}, what can be inferred about the preferences of the user?" 
+prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you speculate on the user preferences based on his historical item interaction records: {inters}?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "What is the preferences of user who has previously interacted with items {inters} sequentially?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Using the user historical interactions as input data, summarize the user's preferences. The historical interactions are provided as follows: {inters}." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "Utilizing the ordered list of the user historical interaction items as a reference, please make an informed estimation of the user's preferences. The historical interactions are as follows: {inters}." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +all_prompt["preferenceobtain"] = preferenceobtain_prompt diff --git a/prompt_finetune.py b/prompt_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..45a3d94da22f2940b8ab42de10883f4ab085cbd1 --- /dev/null +++ b/prompt_finetune.py @@ -0,0 +1,803 @@ +sft_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request." \ + "\n\n### Instruction:\n{instruction}\n\n### Response:{response}" + +all_prompt = {} + +# ===================================================== +# Task 12 -- User2Preference -- 8 Prompt +# ===================================================== + +user2pref_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "What is the preference of user {user}?" 
+prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the preference of user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Briefly summarize the preference of user {user}." +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you share with me the preference description corresponding to user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "How to describe the preference of user {user}?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "I need to know the preference of user {user}. Could you help me with that?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Please provide a description of user {user}'s preference." +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Can you provide the corresponding description for user {user}'s preference?" +prompt["response"] = "{preference}" +user2pref_prompt.append(prompt) + +all_prompt["user2pref"] = user2pref_prompt + +# ===================================================== +# Task 11 -- Preference2User -- 8 Prompt +# ===================================================== + +pref2user_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "A user has the following preference: \"{preference}\". Which user is it describing?" +prompt["response"] = "{user}" +pref2user_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Can you tell which user has such a preference: \"{preference}\"?" 
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——2
+prompt = {}
+prompt["instruction"] = "Your task is to determine the corresponding user based on his preference. Here is the preference description: \"{preference}\"."
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——3
+prompt = {}
+prompt["instruction"] = "Please identify the user from the provided preference: \"{preference}\"."
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——4
+prompt = {}
+prompt["instruction"] = "Which user has the following preference: \"{preference}\"?"
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——5
+prompt = {}
+prompt["instruction"] = "Given the textual description of someone's preference as \"{preference}\", identify the corresponding user."
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——6
+prompt = {}
+prompt["instruction"] = "Based on the provided preference \"{preference}\", answer which user is it referring to?"
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+#####——7
+prompt = {}
+prompt["instruction"] = "Which user can be characterized by the following description: \"{preference}\"?"
+prompt["response"] = "{user}"
+pref2user_prompt.append(prompt)
+
+all_prompt["pref2user"] = pref2user_prompt
+
+# =====================================================
+# Task 10 -- Possible User Prediction -- 10 Prompt
+# =====================================================
+
+usersearch_prompt = []
+
+#####——0
+prompt = {}
+prompt["instruction"] = "The item {item} has been historically clicked by users {users}. Can you predict another possible user which will click this item?"
+prompt["response"] = "{user}"
+usersearch_prompt.append(prompt)
+
+#####——1
+prompt = {}
+prompt["instruction"] = "Given the item {item} and its historical interactive users {users}, I want to know which user will click this item. Please provide a recommendation."
+prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Here are the item {item}'s historical interactive users: {users}, try to predict another user that is fond of this item." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "According to the users {users} that have clicked the item {item}, can you determine the next possible user wanting the same item?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "By analyzing the item {item}'s historical interactions with users {users}, who is the next expected interactive user?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "After clicked by these users {users}, who is the next user that may be keen on the item {item}?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Who is the most potential user for the given item {item} that has previously interacted with the following users {users}?" +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Taking the item {item}'s historical interactions as condition, predict the next user that may highly enjoy the same item. Here is the historical interactive users {users}." +prompt["response"] = "{user}" +usersearch_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "You have access to the item {item}'s historical user interaction record {users}. Now your task is to predict another possible user that loves the same item based on the past interaction." 
+prompt["response"] = "{user}"
+usersearch_prompt.append(prompt)
+
+#####——9
+prompt = {}
+prompt["instruction"] = "Considering the fact that several users {users} have clicked the same item {item}, forecast who is the next user that will be interested in this item."
+prompt["response"] = "{user}"
+usersearch_prompt.append(prompt)
+
+all_prompt["usersearch"] = usersearch_prompt
+
+# =====================================================
+# Task 1 -- Sequential Recommendation -- 17 Prompt
+# =====================================================
+
+seqrec_prompt = []
+
+#####——0
+prompt = {}
+prompt["instruction"] = "The user {user} has interacted with items {inters} in chronological order. Can you predict the next possible item that the user may expect?"
+prompt["response"] = "{item}"
+seqrec_prompt.append(prompt)
+
+#####——1
+prompt = {}
+prompt["instruction"] = "I find the user {user}'s historical interactive items: {inters}, and I want to know what next item the user needs. Can you help me decide?"
+prompt["response"] = "{item}"
+seqrec_prompt.append(prompt)
+
+#####——2
+prompt = {}
+prompt["instruction"] = "Here are the user {user}'s historical interactions: {inters}, try to recommend another item to the user. Note that the historical interactions are arranged in chronological order."
+prompt["response"] = "{item}"
+seqrec_prompt.append(prompt)
+
+#####——3
+prompt = {}
+prompt["instruction"] = "Based on the items that the user {user} has interacted with: {inters}, can you determine what item would be recommended to him next?"
+prompt["response"] = "{item}"
+seqrec_prompt.append(prompt)
+
+#####——4
+prompt = {}
+prompt["instruction"] = "The user {user} has interacted with the following items in order: {inters}. What else do you think the user needs?"
+prompt["response"] = "{item}"
+seqrec_prompt.append(prompt)
+
+#####——5
+prompt = {}
+prompt["instruction"] = "Here is the item interaction history of the user {user}: {inters}, what to recommend to the user next?"
+prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Which item would the user {user} be likely to interact with next after interacting with items {inters}?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "By analyzing the user {user}'s historical interactions with items {inters}, what is the next expected interaction item?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "After interacting with items {inters}, what is the next item that could be recommended for the user {user}?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "Given the user {user}'s historical interactive items arranged in chronological order: {inters}, can you recommend a suitable item for the user?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Considering the user {user} has interacted with items {inters}. What is the next recommendation for the user?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "What is the top recommended item for the user {user} who has previously interacted with items {inters} in order?" +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——12 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with the following items in the past in order: {inters}. Please predict the next item that the user most desires based on the given interaction records." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Using the user {user}'s historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}." 
+prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "You can access the user {user}'s historical item interaction records: {inters}. Now your task is to recommend the next potential item to him, considering his past interactions." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "You have observed that the user {user} has interacted with the following items: {inters}, please recommend a next item that you think would be suitable for the user." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "You have obtained the ordered list of user {user} historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user." +prompt["response"] = "{item}" +seqrec_prompt.append(prompt) + +all_prompt["seqrec"] = seqrec_prompt + + + +# ======================================================== +# Task 2 -- Item2Index -- 19 Prompt +# ======================================================== +# Remove periods when inputting + +item2index_prompt = [] + +# ======================================================== +# Title2Index + +#####——0 +prompt = {} +prompt["instruction"] = "Which item has the title: \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Which item is assigned the title: \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "An item is called \"{title}\", could you please let me know which item it is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Which item is called \"{title}\"?" 
+prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "One of the items is named \"{title}\", can you tell me which item this is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "What is the item that goes by the title \"{title}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +# prompt = {} +# prompt["instruction"] = "Which item is referred to as \"{title}\"?" +# prompt["response"] = "{item}" +# item2index_prompt.append(prompt) + +# ======================================================== +# Description2Index + +#####——6 +prompt = {} +prompt["instruction"] = "An item can be described as follows: \"{description}\". Which item is it describing?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Can you tell me what item is described as \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you provide the item that corresponds to the following description: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + + +# prompt = {} +# prompt["instruction"] = "What is the item described as follows: \"{description}\"?" +# prompt["response"] = "{item}" +# item2index_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "Which item has the following characteristics: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Which item is characterized by the following description: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "I am curious to know which item can be described as follows: \"{description}\". Can you tell me?" 
+prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +# ======================================================== +# Title and Description to index + +#####——12 +prompt = {} +prompt["instruction"] = "An item is called \"{title}\" and described as \"{description}\", can you tell me which item it is?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Could you please identify what item is called \"{title}\" and described as \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "Which item is called \"{title}\" and has the characteristics described below: \"{description}\"?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "Please show me which item is named \"{title}\" and its corresponding description is: \"{description}\"." +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + + +# prompt = {} +# prompt["instruction"] = "Here is an item called \"{title}\" and described as \"{description}\". Which item is it?" +# prompt["response"] = "{item}" +# item2index_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "Determine which item this is by its title and description. The title is: \"{title}\", and the description is: \"{description}\"." +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——17 +prompt = {} +prompt["instruction"] = "Based on the title: \"{title}\", and the description: \"{description}\", answer which item is this?" +prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +#####——18 +prompt = {} +prompt["instruction"] = "Can you identify the item from the provided title: \"{title}\", and description: \"{description}\"?" 
+prompt["response"] = "{item}" +item2index_prompt.append(prompt) + +all_prompt["item2index"] = item2index_prompt + + +# ======================================================== +# Task 3 -- Index2Item --17 Prompt +# ======================================================== +# Remove periods when inputting + +index2item_prompt = [] + +# ======================================================== +# Index2Title + +#####——0 +prompt = {} +prompt["instruction"] = "What is the title of item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "What title is assigned to item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Could you please tell me what item {item} is called?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you provide the title of item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "What item {item} is referred to as?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the title of item {item}?" +prompt["response"] = "{title}" +index2item_prompt.append(prompt) + +# ======================================================== +# Index2Description + +#####——6 +prompt = {} +prompt["instruction"] = "Please provide a description of item {item}." +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Briefly describe item {item}." +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you share with me the description corresponding to item {item}?" 
+prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "What is the description of item {item}?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "How to describe the characteristics of item {item}?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "Could you please tell me what item {item} looks like?" +prompt["response"] = "{description}" +index2item_prompt.append(prompt) + + +# ======================================================== +# index to Title and Description + +#####——12 +prompt = {} +prompt["instruction"] = "What is the title and description of item {item}?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——13 +prompt = {} +prompt["instruction"] = "Can you provide the corresponding title and description for item {item}?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——14 +prompt = {} +prompt["instruction"] = "Please tell me what item {item} is called, along with a brief description of it." +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——15 +prompt = {} +prompt["instruction"] = "Would you mind informing me about the title of the item {item} and how to describe its characteristics?" +prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +#####——16 +prompt = {} +prompt["instruction"] = "I need to know the title and description of item {item}. Could you help me with that?" 
+prompt["response"] = "{title}\n\n{description}" +index2item_prompt.append(prompt) + +all_prompt["index2item"] = index2item_prompt + + + + + +# ======================================================== +# Task 4 -- FusionSequentialRec -- Prompt +# ======================================================== + + +fusionseqrec_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The user {user} has sequentially interacted with items {inters}. Can you recommend the next item for him? Tell me the title of the item?" +prompt["response"] = "{title}" +fusionseqrec_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Based on the user {user}'s historical interactions: {inters}, try to predict the title of the item that the user may need next." +prompt["response"] = "{title}" +fusionseqrec_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "Utilizing the user {user}'s past ordered interactions, which include items {inters}, please recommend the next item you think is suitable for the user and provide its title." +prompt["response"] = "{title}" +fusionseqrec_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "After interacting with items {inters}, what is the most probable item for the user {user} to interact with next? Kindly provide the item's title." +prompt["response"] = "{title}" +fusionseqrec_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "Please review the user {user}'s historical interactions: {inters}, and describe what kind of item he still needs." +prompt["response"] = "{description}" +fusionseqrec_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Here is the item interaction history of the user {user}: {inters}, please tell me what features he expects from his next item." 
+prompt["response"] = "{description}" +fusionseqrec_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "By analyzing the user {user}'s historical interactions with items {inters}, can you infer what the user's next interactive item will look like?" +prompt["response"] = "{description}" +fusionseqrec_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Access the user {user}'s historical item interaction records: {inters}. Your objective is to describe the next potential item for him, taking into account his past interactions." +prompt["response"] = "{description}" +fusionseqrec_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Given the title sequence of user {user} historical interactive items: {inter_titles}, can you recommend a suitable next item for the user?" +prompt["response"] = "{item}" +fusionseqrec_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "I possess a user {user}'s past interaction history, denoted by the title sequence of interactive items: {inter_titles}, and I am interested in knowing the user's next most desired item. Can you help me?" +prompt["response"] = "{item}" +fusionseqrec_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Considering the title sequence of user {user} history interaction items: {inter_titles}. What is the next recommendation for the user?" +prompt["response"] = "{item}" +fusionseqrec_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "You have obtained the ordered title list of user {user} historical interaction items, as follows: {inter_titles}. Based on this historical context, kindly choose the subsequent item for user recommendation." 
+prompt["response"] = "{item}" +fusionseqrec_prompt.append(prompt) + + +all_prompt["fusionseqrec"] = fusionseqrec_prompt + +# ======================================================== +# Task 5 -- ItemSearch -- Prompt +# ======================================================== + + +itemsearch_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "Here is the historical interactions of a user {user}: {inters}. And his personalized preferences are as follows: \"{explicit_preference}\". Your task is to recommend an item that is consistent with the user's preference." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with a list of items, which are as follows: {inters}. Based on these interacted items, the user current intent is as follows \"{user_related_intention}\", and your task is to generate an item that matches the user's current intent." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——2 +prompt = {} +prompt["instruction"] = "As a recommender system, you are assisting a user {user} who has recently interacted with the following items: {inters}. The user expresses a desire to obtain another item with the following characteristics: \"{item_related_intention}\". Please recommend an item that meets these criteria." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Using the user {user}'s current query: \"{query}\" and his historical interactions: {inters}, you can estimate the user's preferences \"{explicit_preference}\". Please respond to the user's query by selecting an item that best matches his preference and query." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "The user {user} needs a new item and searches for: \"{query}\". In addition, he has previously interacted with: {inters}. 
You can obtain his preference by analyzing his historical interactions: \"{explicit_preference}\". Can you recommend an item that best matches the search query and preferences?" +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Based on the user {user}'s historical interactions with the following items: {inters}. You can infer his preference by observing the historical interactions: \"{explicit_preference}\". Now the user wants a new item and searches for: \"{query}\". Please select a suitable item that matches his preference and search intent." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Suppose you are a search engine, now a user {user} searches that: \"{query}\", can you select an item to respond to the user's query?" +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "As a search engine, your task is to answer the user's query by generating a related item. The user {user}'s query is provided as \"{query}\". Please provide your generated item as your answer." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "As a recommender system, your task is to recommend an item that is related to the user {user}'s request, which is specified as follows: \"{query}\". Please provide your recommendation." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "You meet a user {user}'s query: \"{query}\". Please respond to this user by selecting an appropriate item." +prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + + +#####——10 +prompt = {} +prompt["instruction"] = "Your task is to recommend the best item that matches the user {user}'s query. Here is the search query of the user: \"{query}\", tell me the item you recommend." 
+prompt["response"] = "{item}" +itemsearch_prompt.append(prompt) + +all_prompt["itemsearch"] = itemsearch_prompt + +# ======================================================== +# Task 6 -- PreferenceObtain -- Prompt +# ======================================================== + +preferenceobtain_prompt = [] + +#####——0 +prompt = {} +prompt["instruction"] = "The user {user} has interacted with items {inters} in chronological order. Please estimate his preferences." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——1 +prompt = {} +prompt["instruction"] = "Based on the items that the user {user} has interacted with: {inters}, can you infer what preferences he has?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——3 +prompt = {} +prompt["instruction"] = "Can you provide a summary of the user {user}'s preferences based on his historical interactions: {inters}?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——4 +prompt = {} +prompt["instruction"] = "After interacting with items {inters} in order, what preferences do you think the user {user} has?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——5 +prompt = {} +prompt["instruction"] = "Here is the item interaction history of the user: {inters}, could you please infer the user {user}'s preferences." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——6 +prompt = {} +prompt["instruction"] = "Based on the user {user}'s historical interaction records: {inters}, what are your speculations about his preferences?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——7 +prompt = {} +prompt["instruction"] = "Given the user {user}'s historical interactive items arranged in chronological order: {inters}, what can be inferred about the preferences of the user?" 
+prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——8 +prompt = {} +prompt["instruction"] = "Can you speculate on the user {user}'s preferences based on his historical item interaction records: {inters}?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——9 +prompt = {} +prompt["instruction"] = "What is the preferences of a user {user} who has previously interacted with items {inters} sequentially?" +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——10 +prompt = {} +prompt["instruction"] = "Using the user {user}'s historical interactions as input data, summarize the user's preferences. The historical interactions are provided as follows: {inters}." +prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +#####——11 +prompt = {} +prompt["instruction"] = "Utilizing the ordered list of the user {user}'s historical interaction items as a reference, please make an informed estimation of the user's preferences. The historical interactions are as follows: {inters}." 
+prompt["response"] = "{explicit_preference}" +preferenceobtain_prompt.append(prompt) + +all_prompt["preferenceobtain"] = preferenceobtain_prompt diff --git a/rq_llama.py b/rq_llama.py new file mode 100644 index 0000000000000000000000000000000000000000..7090cc06d57d83d7b895e6b6a09cfd6b780cb117 --- /dev/null +++ b/rq_llama.py @@ -0,0 +1,569 @@ +import os +import json +import copy +import wandb +import torch +import torch.nn as nn +import transformers +from transformers import LlamaPreTrainedModel, LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from peft import ( + TaskType, + LoraConfig, + get_peft_model, + get_peft_model_state_dict, + set_peft_model_state_dict, +) + +from index.models import * +from index.models.rqvae import RQVAE + +from torch.nn.init import xavier_normal_ +from sklearn.cluster import KMeans + +class LlamaWithRQ(LlamaPreTrainedModel): + def __init__(self, config): + super().__init__(config) + + args = config.args + + tokenizer = LlamaTokenizer.from_pretrained( + args['base_model'], + model_max_length = args['model_max_length'], + padding_side ="right", + ) + tokenizer.pad_token_id = 0 + + item_tokens = [] + prefix = ['','','','',''] + for i in range(len(args['num_emb_list'])): + item_tokens.extend([prefix[i].format(int(x)) for x in range(args['num_emb_list'][i])]) + self.prefix = prefix + + user_tokens = [] + user_prefix = ['','','','',''] + for i in range(len(args['num_emb_list'])): + user_tokens.extend([user_prefix[i].format(int(x)) for x in range(args['num_emb_list'][i])]) + self.user_prefix = user_prefix + + tokenizer.add_tokens(item_tokens) + tokenizer.add_tokens(user_tokens) + config.vocab_size = len(tokenizer) + + llama_model = LlamaForCausalLM.from_pretrained(args['base_model']) + llama_model.resize_token_embeddings(len(tokenizer)) + + lora_config = LoraConfig( + r = args['lora_r'], + lora_alpha = args['lora_alpha'], + target_modules = args['lora_target_modules'].split(","), + modules_to_save = args['lora_modules_to_save'].split(","), + 
lora_dropout = args['lora_dropout'], + bias = "none", + inference_mode = False, + task_type = TaskType.CAUSAL_LM + ) + llama_model = get_peft_model(llama_model, lora_config) + + for n, p in llama_model.named_parameters(): + if "original_module" in n and any(module_name in n for module_name in lora_config.modules_to_save): + p.requires_grad = False + + self.tokenizer = tokenizer + self.model = llama_model + + item_json = os.path.join(args['data_path'], args['dataset'], args['dataset'] + ".item.json") + with open(item_json, 'r') as f: + self.item_texts = json.load(f) + + user_json = os.path.join(args['data_path'], args['dataset'], args['dataset'] + ".user.json") + with open(user_json, 'r') as f: + self.user_texts = json.load(f) + self.user_texts = self.user_texts['user_explicit_preference'] + + self.item_rqvae = RQVAE( + in_dim = config.hidden_size, + num_emb_list = args['num_emb_list'], + e_dim = args['e_dim'], + layers = args['layers'], + dropout_prob = args['dropout_prob'], + bn = args['bn'], + loss_type = args['loss_type'], + quant_loss_weight = args['quant_loss_weight'], + kmeans_init = args['kmeans_init'], + kmeans_iters = args['kmeans_iters'], + sk_epsilons = args['sk_epsilons'], + sk_iters = args['sk_iters']) + + self.user_rqvae = RQVAE( + in_dim = config.hidden_size, + num_emb_list = args['num_emb_list'], + e_dim = args['e_dim'], + layers = args['layers'], + dropout_prob = args['dropout_prob'], + bn = args['bn'], + loss_type = args['loss_type'], + quant_loss_weight = args['quant_loss_weight'], + kmeans_init = args['kmeans_init'], + kmeans_iters = args['kmeans_iters'], + sk_epsilons = args['sk_epsilons'], + sk_iters = args['sk_iters']) + + # self.projector = nn.Linear(args['e_dim'], config.hidden_size) + # self.item_projector = nn.Linear(args['e_dim'], config.hidden_size) + # self.user_projector = nn.Linear(args['e_dim'], config.hidden_size) + self.args = args + + def rqvae_forward(self, inputs, targets, inters, item, users, user, task): + llama_model = 
self.model.get_decoder() + if task.lower() == 'seqrec': + # inters, user, item + + inter_feature_list = [] + inter_emb_list = [] + inter_item_list = inters.split(',') + for j in range(len(inter_item_list)): + inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description'] + inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) + inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) + inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) + inter_emb_list.append(inter_emb.detach()) + inter_embs = torch.cat(inter_emb_list, dim = 0) + + item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description'] + item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + item_emb = item_emb.detach() + + user_feature = " ".join(self.user_texts[user]) + user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim=1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + user_emb = user_emb.detach() + + item_rec_embs, item_rq_loss, item_rqids = self.item_rqvae(torch.cat([inter_embs, item_emb], dim = 0)) + item_rqvae_loss, item_rec_loss = self.item_rqvae.compute_loss(item_rec_embs, item_rq_loss, 
torch.cat([inter_embs, item_emb], dim = 0)) + + user_rec_emb, user_rq_loss, user_rqids = self.user_rqvae(user_emb) + user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, user_emb) + + inters_rqids = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[:-1] + item_rqid = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[-1] + user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[0] + + text_rqids = {} + code = '' + for rqid in inters_rqids: + for k, idx in enumerate(rqid): + code = code + self.prefix[k].format(idx) + code = code + ', ' + text_rqids['inters'] = code[:-2] + code = '' + for k, idx in enumerate(item_rqid): + code = code + self.prefix[k].format(idx) + text_rqids['item'] = code + code = '' + for k, idx in enumerate(user_rqid): + code = code + self.user_prefix[k].format(idx) + text_rqids['user'] = code + + inputs = inputs.format(inters = text_rqids['inters'], user = text_rqids['user']) + targets = targets.format(inters = text_rqids['inters'], user = text_rqids['user'], item = text_rqids['item']) + + num_item = item_rec_embs.shape[0] + num_user = user_rec_emb.shape[0] + + elif task.lower() == 'itemsearch': + # inters, item + inter_feature_list = [] + inter_emb_list = [] + inter_item_list = inters.split(',') + for j in range(len(inter_item_list)): + inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description'] + inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) + inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) + inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) + inter_emb_list.append(inter_emb.detach()) + inter_embs = torch.cat(inter_emb_list, dim = 0) + + item_feature = 
self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description'] + item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + item_emb = item_emb.detach() + + # user_feature = " ".join(self.user_texts[user]) + # user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + # user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + # user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + # user_emb = user_emb.sum(dim=1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + # user_emb = user_emb.detach() + + item_rec_embs, item_rq_loss, item_rqids = self.item_rqvae(torch.cat([inter_embs, item_emb], dim = 0)) + item_rqvae_loss, item_rec_loss = self.item_rqvae.compute_loss(item_rec_embs, item_rq_loss, torch.cat([inter_embs, item_emb], dim = 0)) + + # user_rec_emb, user_rq_loss, user_rqids = self.user_rqvae(user_emb) + # user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, user_emb) + + inters_rqids = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[:-1] + item_rqid = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[-1] + # user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[0] + + text_rqids = {} + code = '' + for rqid in inters_rqids: + for k, idx in enumerate(rqid): + code = code + self.prefix[k].format(idx) + code = code + ', ' + text_rqids['inters'] = code[:-2] + code = '' + for k, idx in enumerate(item_rqid): + code = code + self.prefix[k].format(idx) + text_rqids['item'] = code + # code = '' + # for k, idx in 
enumerate(user_rqid): + # code = code + self.user_prefix[k].format(idx) + # text_rqids['user'] = code + + inputs = inputs.format(inters = text_rqids['inters']) + targets = targets.format(inters = text_rqids['inters'], item = text_rqids['item']) + + num_item = item_rec_embs.shape[0] + num_user = 0 + user_rqvae_loss = 0 + + elif task.lower() in ['inters2title','inters2description']: + # inputs, targets, inters, user + inter_feature_list = [] + inter_emb_list = [] + inter_item_list = inters.split(',') + for j in range(len(inter_item_list)): + inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description'] + inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) + inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) + inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) + inter_emb_list.append(inter_emb.detach()) + inter_embs = torch.cat(inter_emb_list, dim = 0) + + user_feature = " ".join(self.user_texts[user]) + user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim=1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + user_emb = user_emb.detach() + + item_rec_embs, item_rq_loss, item_rqids = self.item_rqvae(inter_embs) + item_rqvae_loss, item_rec_loss = self.item_rqvae.compute_loss(item_rec_embs, item_rq_loss, inter_embs) + + user_rec_emb, user_rq_loss, user_rqids = self.user_rqvae(user_emb) + user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, user_emb) + + inters_rqids = 
item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist() + user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[0] + + text_rqids = {} + code = '' + for rqid in inters_rqids: + for k, idx in enumerate(rqid): + code = code + self.prefix[k].format(idx) + code = code + ', ' + text_rqids['inters'] = code[:-2] + code = '' + for k, idx in enumerate(user_rqid): + code = code + self.user_prefix[k].format(idx) + text_rqids['user'] = code + + inputs = inputs.format(inters = text_rqids['inters'], user = text_rqids['user']) + targets = targets.format(inters = text_rqids['inters'], user = text_rqids['user']) + + num_item = item_rec_embs.shape[0] + num_user = user_rec_emb.shape[0] + + elif task.lower() in ['intertitles2item','query2item']: + # inputs, targets, item, user + item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description'] + item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + item_emb = item_emb.detach() + + user_feature = " ".join(self.user_texts[user]) + user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim=1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + user_emb = user_emb.detach() + + item_rec_embs, item_rq_loss, item_rqids = self.item_rqvae(item_emb) + item_rqvae_loss, item_rec_loss = self.item_rqvae.compute_loss(item_rec_embs, item_rq_loss, item_emb) + + user_rec_emb, user_rq_loss, user_rqids 
= self.user_rqvae(user_emb) + user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, user_emb) + + item_rqid = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[0] + user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[0] + + text_rqids = {} + code = '' + for k, idx in enumerate(item_rqid): + code = code + self.prefix[k].format(idx) + text_rqids['item'] = code + code = '' + for k, idx in enumerate(user_rqid): + code = code + self.user_prefix[k].format(idx) + text_rqids['user'] = code + + inputs = inputs.format(user = text_rqids['user']) + targets = targets.format(item = text_rqids['item'], user = text_rqids['user']) + + num_item = item_rec_embs.shape[0] + num_user = user_rec_emb.shape[0] + + elif task.lower() in ['item2index','index2item']: + # inputs, targets, item + item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description'] + item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + item_emb = item_emb.detach() + + rec_embs, rq_loss, rqids = self.item_rqvae(item_emb) + rqvae_loss, rec_loss = self.item_rqvae.compute_loss(rec_embs, rq_loss, item_emb) + + item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[0] + code = '' + for k, idx in enumerate(item_rqid): + code = code + self.prefix[k].format(idx) + + if task.lower() == 'item2index': + targets = targets.format(item = code) + elif task.lower() == 'index2item': + inputs = inputs.format(item = code) + targets = targets.format(item = code) + else: + raise NotImplementedError + + item_rqvae_loss = rqvae_loss + user_rqvae_loss = 0 + num_item = rec_embs.shape[0] + num_user = 0 + 
+ elif task.lower() == 'preferenceobtain': + # inputs, targets, inters + inter_feature_list = [] + inter_emb_list = [] + inter_item_list = inters.split(',') + for j in range(len(inter_item_list)): + inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description'] + inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask) + inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1) + inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True) + inter_emb_list.append(inter_emb.detach()) + inter_embs = torch.cat(inter_emb_list, dim = 0) + + rec_embs, rq_loss, rqids = self.item_rqvae(inter_embs) + rqvae_loss, rec_loss = self.item_rqvae.compute_loss(rec_embs, rq_loss, inter_embs) + + inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist() + code = '' + for rqid in inters_rqids: + for k, idx in enumerate(rqid): + code = code + self.prefix[k].format(idx) + code = code + ', ' + code = code[:-2] + + inputs = inputs.format(inters = code) + targets = targets.format(inters = code) + + item_rqvae_loss = rqvae_loss + user_rqvae_loss = 0 + num_item = rec_embs.shape[0] + num_user = 0 + + elif task.lower() == 'usersearch': + # item, users, user + item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description'] + item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device) + item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask) + item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1) + item_emb = item_emb.sum(dim = 1) / item_ids.attention_mask.sum(dim = -1, keepdim = True) + item_emb = item_emb.detach() + + users_emb_list = [] + users_list = users.split(',') + for 
j in range(len(users_list)): + u_feature = " ".join(self.user_texts[users_list[j]]) + u_id = self.tokenizer(u_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + u_emb = llama_model(input_ids = u_id.input_ids, attention_mask = u_id.attention_mask) + u_emb = u_emb.last_hidden_state * u_id.attention_mask.unsqueeze(-1) + u_emb = u_emb.sum(dim = 1) / u_id.attention_mask.sum(dim = -1, keepdim = True) + users_emb_list.append(u_emb.detach()) + users_emb = torch.cat(users_emb_list, dim = 0) + + user_feature = " ".join(self.user_texts[user]) + user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim = 1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + user_emb = user_emb.detach() + + item_rec_embs, item_rq_loss, item_rqids = self.item_rqvae(item_emb) + item_rqvae_loss, item_rec_loss = self.item_rqvae.compute_loss(item_rec_embs, item_rq_loss, item_emb) + + user_rec_emb, user_rq_loss, user_rqids = self.user_rqvae(torch.cat([users_emb, user_emb], dim = 0)) + user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, torch.cat([users_emb, user_emb], dim = 0)) + + item_rqid = item_rqids.view(-1, item_rqids.shape[-1]).cpu().numpy().tolist()[0] + users_rqids = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[:-1] + user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[-1] + + text_rqids = {} + code = '' + for k, idx in enumerate(item_rqid): + code = code + self.prefix[k].format(idx) + text_rqids['item'] = code + code = '' + for rqid in users_rqids: + for k, idx in enumerate(rqid): + code = code + self.user_prefix[k].format(idx) + code = code + ', ' + text_rqids['users'] = code[:-2] + code = '' + 
for k, idx in enumerate(user_rqid): + code = code + self.user_prefix[k].format(idx) + text_rqids['user'] = code + + inputs = inputs.format(item = text_rqids['item'], users = text_rqids['users']) + targets = targets.format(item = text_rqids['item'], users = text_rqids['users'], user = text_rqids['user']) + + num_item = item_rec_embs.shape[0] + num_user = user_rec_emb.shape[0] + + elif task.lower() in ['pref2user','user2pref']: + # inputs, targets, user + user_feature = " ".join(self.user_texts[user]) + user_ids = self.tokenizer(user_feature, return_tensors = 'pt', padding = True, truncation = True).to(self.model.device) + user_emb = llama_model(input_ids = user_ids.input_ids, attention_mask = user_ids.attention_mask) + user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1) + user_emb = user_emb.sum(dim = 1) / user_ids.attention_mask.sum(dim = -1, keepdim = True) + user_emb = user_emb.detach() + + user_rec_emb, user_rq_loss, user_rqids = self.user_rqvae(user_emb) + user_rqvae_loss, user_rec_loss = self.user_rqvae.compute_loss(user_rec_emb, user_rq_loss, user_emb) + + user_rqid = user_rqids.view(-1, user_rqids.shape[-1]).cpu().numpy().tolist()[0] + code = '' + for k, idx in enumerate(user_rqid): + code = code + self.user_prefix[k].format(idx) + + if task.lower() == 'pref2user': + targets = targets.format(user = code) + elif task.lower() == 'user2pref': + inputs = inputs.format(user = code) + targets = targets.format(user = code) + else: + raise NotImplementedError + + item_rqvae_loss = 0 + num_item = 0 + num_user = user_rec_emb.shape[0] + + else: + raise NotImplementedError + + return inputs, targets, item_rqvae_loss, user_rqvae_loss, num_item, num_user + + def forward(self, input_ids, labels, inters, item, users, user, task): + ''' + 'input_ids': + [ + "Below is an instruction that describes a task. Write a response that appropriately completes the request. 
+ ### Instruction: + Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. + The historical interactions are provided as follows: {inters}. + ### Response:", + + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. + ### Instruction: + You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. + Using this history as a reference, please select the next item to recommend to the user. + ### Response:' + ], + + 'labels': + [ + "Below is an instruction that describes a task. Write a response that appropriately completes the request. + ### Instruction: + Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. + The historical interactions are provided as follows: {inters}. + ### Response:{item}", + + 'Below is an instruction that describes a task. Write a response that appropriately completes the request. + ### Instruction: + You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. + Using this history as a reference, please select the next item to recommend to the user. 
+ ### Response:{item}' + ], + + 'inters': ['0', '0,1'], + 'item': ['1', '2'], + 'task': ['seqrec', 'seqrec'] + ''' + assert len(set([len(input_ids), len(labels), len(inters), len(item), len(user), len(task), len(users)])) == 1 + num_data = len(task) + + total_item_rqvae_loss, total_user_rqvae_loss = 0, 0 + total_num_item, total_num_user = 1e-8, 1e-8 + for i in range(num_data): + input_ids[i], labels[i], item_rqvae_loss, user_rqvae_loss, num_item, num_user = self.rqvae_forward( + input_ids[i], labels[i], inters[i], item[i], users[i], user[i], task[i]) + total_item_rqvae_loss = total_item_rqvae_loss + item_rqvae_loss * num_item + total_user_rqvae_loss = total_user_rqvae_loss + user_rqvae_loss * num_user + total_num_item += num_item + total_num_user += num_user + + input_data = self.tokenizer( + text = labels, + text_target = input_ids, + return_tensors = 'pt', + padding = 'longest', + truncation = True, + max_length = self.tokenizer.model_max_length, + return_attention_mask = True + ).to(self.model.device) + + labels = copy.deepcopy(input_data["input_ids"]) + if self.args['only_train_response']: + labels[labels == self.tokenizer.pad_token_id] = -100 + labels[torch.where(input_data["labels"] != self.tokenizer.pad_token_id)] = -100 + + input_data["labels"] = labels + + # codebook_embedding = [] + # for i in range(len(self.item_rqvae.num_emb_list)): + # codebook_embedding.append(self.item_rqvae.rq.vq_layers[i].embedding.weight.data) + # for i in range(len(self.user_rqvae.num_emb_list)): + # codebook_embedding.append(self.user_rqvae.rq.vq_layers[i].embedding.weight.data) + # codebook_embedding = torch.cat(codebook_embedding, dim = 0) + # codebook_embedding = self.projector(codebook_embedding) + # self.model.model.model.embed_tokens.weight.data[-codebook_embedding.shape[0]:] = codebook_embedding + + # input_data: dict_keys(['input_ids', 'attention_mask', 'labels']) + result = self.model(**input_data) + # wandb.log({'Llama_Loss': result.loss, 'RQVAE_Loss': total_rqvae_loss 
/ total_num_sample})
+        result.loss = result.loss + total_item_rqvae_loss / total_num_item + total_user_rqvae_loss / total_num_user
+        # wandb.log({'Total_Loss': result.loss})
+        return result
+
+    def floating_point_ops(self, inputs):
+        return 0
\ No newline at end of file
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..03b9c6f6279c82ce8e169a1d971ff19dad42698c
--- /dev/null
+++ b/run.sh
@@ -0,0 +1,98 @@
+export WANDB_MODE=disabled
+export CUDA_LAUNCH_BLOCKING=1
+
+DATASET=Games
+BASE_MODEL=huggyllama/llama-7b
+DATA_PATH=./data
+OUTPUT_DIR=./ckpt/$DATASET/
+
+torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
+    --base_model $BASE_MODEL \
+    --output_dir $OUTPUT_DIR \
+    --dataset $DATASET \
+    --data_path $DATA_PATH \
+    --per_device_batch_size 8 \
+    --gradient_accumulation_steps 2 \
+    --learning_rate 5e-5 \
+    --epochs 4 \
+    --weight_decay 0.01 \
+    --save_and_eval_strategy epoch \
+    --deepspeed ./config/ds_z3_bf16.json \
+    --bf16 \
+    --only_train_response \
+    --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
+    --train_prompt_sample_num 1,1,1,1,1,1 \
+    --train_data_sample_num 0,0,0,100000,0,0 \
+    --index_file .index.json
+
+
+cd convert
+nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
+cd ..
+
+
+
+
+
+
+DATASET=Arts
+BASE_MODEL=huggyllama/llama-7b
+DATA_PATH=./data
+OUTPUT_DIR=./ckpt/$DATASET/
+
+torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
+    --base_model $BASE_MODEL \
+    --output_dir $OUTPUT_DIR \
+    --dataset $DATASET \
+    --data_path $DATA_PATH \
+    --per_device_batch_size 8 \
+    --gradient_accumulation_steps 2 \
+    --learning_rate 5e-5 \
+    --epochs 4 \
+    --weight_decay 0.01 \
+    --save_and_eval_strategy epoch \
+    --deepspeed ./config/ds_z3_bf16.json \
+    --bf16 \
+    --only_train_response \
+    --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
+    --train_prompt_sample_num 1,1,1,1,1,1 \
+    --train_data_sample_num 0,0,0,30000,0,0 \
+    --index_file .index.json
+
+
+cd convert
+nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
+cd ..
+
+
+
+
+
+DATASET=Instruments
+BASE_MODEL=huggyllama/llama-7b
+DATA_PATH=./data
+OUTPUT_DIR=./ckpt/$DATASET/
+
+torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
+    --base_model $BASE_MODEL \
+    --output_dir $OUTPUT_DIR \
+    --dataset $DATASET \
+    --data_path $DATA_PATH \
+    --per_device_batch_size 8 \
+    --gradient_accumulation_steps 2 \
+    --learning_rate 5e-5 \
+    --epochs 4 \
+    --weight_decay 0.01 \
+    --save_and_eval_strategy epoch \
+    --deepspeed ./config/ds_z3_bf16.json \
+    --bf16 \
+    --only_train_response \
+    --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
+    --train_prompt_sample_num 1,1,1,1,1,1 \
+    --train_data_sample_num 0,0,0,20000,0,0 \
+    --index_file .index.json
+
+
+cd convert
+nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
+cd ..
diff --git a/run_test.sh b/run_test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1a8776833f95c9813a3ee4daeca7e6e66be0a609
--- /dev/null
+++ b/run_test.sh
@@ -0,0 +1,17 @@
+
+
+DATASET=Games
+DATA_PATH=./data
+CKPT_PATH=./ckpt/$DATASET/
+RESULTS_FILE=./results/$DATASET/xxx.json
+
+python test.py \
+    --gpu_id 0 \
+    --ckpt_path $CKPT_PATH \
+    --dataset $DATASET \
+    --data_path $DATA_PATH \
+    --results_file $RESULTS_FILE \
+    --test_batch_size 1 \
+    --num_beams 20 \
+    --test_prompt_ids all \
+    --index_file .index.json
diff --git a/run_test_ddp.sh b/run_test_ddp.sh
new file mode 100644
index 0000000000000000000000000000000000000000..efa9459202985bf9460af1a71d82cced6001c816
--- /dev/null
+++ b/run_test_ddp.sh
@@ -0,0 +1,15 @@
+DATASET=Instruments
+DATA_PATH=${datain}/v-yinju/rqvae-zzx/data
+CKPT_PATH=${datain}/v-yinju/rq-llama/v3-train/$DATASET/second/finetune
+RESULTS_FILE=$CKPT_PATH/eval_result.json
+
+torchrun --nproc_per_node=2 --master_port=4324 evaluate-finetuned.py \
+    --ckpt_path $CKPT_PATH \
+    --dataset $DATASET \
+    --data_path $DATA_PATH \
+    --results_file $RESULTS_FILE \
+    --test_batch_size 1 \
+    --num_beams 20 \
+    --test_prompt_ids all \
+    --test_task seqrec \
+    --index_file $CKPT_PATH/indices.json
\ No newline at end of file
diff --git a/test-main.py b/test-main.py
new file mode 100644
index 0000000000000000000000000000000000000000..60a38fbe849ccd382d945a92eabf67564fa9438e
--- /dev/null
+++ b/test-main.py
@@ -0,0 +1,184 @@
+import argparse
+import json
+import os
+import sys
+
+import torch
+import transformers
+import torch.distributed as dist
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel
+from peft import PeftModel
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
+
+from utils import *
+from collator import VanillaCollator, TestCollator
+from prompt import all_prompt
+from 
evaluate import get_topk_results, get_metrics_results +from rq_llama import * + +parser = argparse.ArgumentParser(description="RQ-Llama Evaluation") +parser = parse_global_args(parser) +parser = parse_dataset_args(parser) +parser = parse_test_args(parser) +args = parser.parse_args() + +set_seed(args.seed) +world_size = int(os.environ.get("WORLD_SIZE", 1)) +local_rank = int(os.environ.get("LOCAL_RANK") or 0) +torch.cuda.set_device(local_rank) +if local_rank == 0: + print(vars(args)) + +dist.init_process_group(backend = "nccl", world_size = world_size, rank = local_rank) + +device_map = {"": local_rank} +device = torch.device("cuda",local_rank) + +rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map) +rqllama = DistributedDataParallel(rqllama, device_ids = [local_rank]) + +if args.test_prompt_ids == "all": + if args.test_task.lower() == "seqrec": + prompt_ids = range(len(all_prompt["seqrec"])) + elif args.test_task.lower() == "itemsearch": + prompt_ids = range(len(all_prompt["itemsearch"])) + elif args.test_task.lower() == "fusionseqrec": + prompt_ids = range(len(all_prompt["fusionseqrec"])) +else: + prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")] + +test_data = load_test_dataset(args) +if local_rank == 0: + print("data num:", len(test_data)) +ddp_sampler = DistributedSampler(test_data, num_replicas = world_size, rank = local_rank, drop_last = True) +collator = TestCollator(args, rqllama.module.tokenizer) +all_items = test_data.get_all_items() +# print('num_items:', len(all_items)) +prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(rqllama.module.tokenizer) + +test_loader = DataLoader( + test_data, + batch_size = args.test_batch_size, + collate_fn = collator, + sampler = ddp_sampler, + num_workers = 2, + pin_memory = True +) + +rqllama.eval() + +metrics = args.metrics.split(",") +all_prompt_results = [] + +with torch.no_grad(): + for prompt_id in prompt_ids: + 
if local_rank == 0: + print("Start prompt: ",prompt_id) + + test_loader.dataset.set_prompt(prompt_id) + metrics_results = {} + total = 0 + for step, batch in enumerate(tqdm(test_loader)): + inputs = batch[0].to(device) + targets = batch[1] + bs = len(targets) + num_beams = args.num_beams + + while True: + try: + output = rqllama.module.model.generate( + input_ids = inputs["input_ids"], + attention_mask = inputs["attention_mask"], + max_new_tokens = 10, + prefix_allowed_tokens_fn = prefix_allowed_tokens, + num_beams = num_beams, + num_return_sequences = num_beams, + output_scores = True, + return_dict_in_generate = True, + early_stopping = True, + ) + break + except torch.cuda.OutOfMemoryError as e: + print("Out of memory!") + num_beams = num_beams -1 + print("Beam:", num_beams) + except Exception: + raise RuntimeError + + output_ids = output["sequences"] + scores = output["sequences_scores"] + + output = rqllama.module.tokenizer.batch_decode(output_ids, skip_special_tokens = True) + topk_res = get_topk_results( + output, + scores, + targets, + num_beams, + all_items = all_items if args.filter_items else None + ) + + bs_gather_list = [None for _ in range(world_size)] + dist.all_gather_object(obj = bs, object_list = bs_gather_list) + total += sum(bs_gather_list) + res_gather_list = [None for _ in range(world_size)] + dist.all_gather_object(obj = topk_res, object_list = res_gather_list) + + if local_rank == 0: + all_device_topk_res = [] + for ga_res in res_gather_list: + all_device_topk_res += ga_res + batch_metrics_res = get_metrics_results(all_device_topk_res, metrics) + for m, res in batch_metrics_res.items(): + if m not in metrics_results: + metrics_results[m] = res + else: + metrics_results[m] += res + if (step + 1) % 50 == 0: + temp = {} + for m in metrics_results: + temp[m] = metrics_results[m] / total + print(temp) + dist.barrier() + + if local_rank == 0: + for m in metrics_results: + metrics_results[m] = metrics_results[m] / total + + 
all_prompt_results.append(metrics_results) + print("======================================================") + print("Prompt {} results: ".format(prompt_id), metrics_results) + print("======================================================") + print("") + dist.barrier() +dist.barrier() + +if local_rank == 0: + mean_results = {} + min_results = {} + max_results = {} + + for m in metrics: + all_res = [_[m] for _ in all_prompt_results] + mean_results[m] = sum(all_res) / len(all_res) + min_results[m] = min(all_res) + max_results[m] = max(all_res) + + print("======================================================") + print("Mean results: ", mean_results) + print("Min results: ", min_results) + print("Max results: ", max_results) + print("======================================================") + + save_data={} + save_data["test_prompt_ids"] = args.test_prompt_ids + save_data["mean_results"] = mean_results + save_data["min_results"] = min_results + save_data["max_results"] = max_results + save_data["all_prompt_results"] = all_prompt_results + + with open(args.results_file, "w") as f: + json.dump(save_data, f, indent = 4) + print("Save file: ", args.results_file) \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff1217b58487baecadaf39fef4bd26e365d9383 --- /dev/null +++ b/test.py @@ -0,0 +1,175 @@ +import argparse +import json +import os +import sys +from typing import List + +import torch +import transformers +from peft import PeftModel +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from utils import * +from collator import TestCollator +from prompt import all_prompt +from evaluate import get_topk_results, get_metrics_results + + +def test(args): + + set_seed(args.seed) + print(vars(args)) + + device_map = {"": args.gpu_id} + device = torch.device("cuda",args.gpu_id) + + + tokenizer = 
LlamaTokenizer.from_pretrained(args.ckpt_path) + if args.lora: + model = LlamaForCausalLM.from_pretrained( + args.base_model, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map=device_map, + ) + model.resize_token_embeddings(len(tokenizer)) + model = PeftModel.from_pretrained( + model, + args.ckpt_path, + torch_dtype=torch.bfloat16, + device_map=device_map, + ) + else: + model = LlamaForCausalLM.from_pretrained( + args.ckpt_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map=device_map, + ) + # assert model.config.vocab_size == len(tokenizer) + + if args.test_prompt_ids == "all": + if args.test_task.lower() == "seqrec": + prompt_ids = range(len(all_prompt["seqrec"])) + elif args.test_task.lower() == "itemsearch": + prompt_ids = range(len(all_prompt["itemsearch"])) + elif args.test_task.lower() == "fusionseqrec": + prompt_ids = range(len(all_prompt["fusionseqrec"])) + else: + prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")] + + test_data = load_test_dataset(args) + collator = TestCollator(args, tokenizer) + all_items = test_data.get_all_items() + + + prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer) + + test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator, + shuffle=True, num_workers=4, pin_memory=True) + + + print("data num:", len(test_data)) + + model.eval() + + metrics = args.metrics.split(",") + all_prompt_results = [] + with torch.no_grad(): + for prompt_id in prompt_ids: + + test_loader.dataset.set_prompt(prompt_id) + metrics_results = {} + total = 0 + + for step, batch in enumerate(tqdm(test_loader)): + inputs = batch[0].to(device) + targets = batch[1] + total += len(targets) + + output = model.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=10, + # max_length=10, + prefix_allowed_tokens_fn=prefix_allowed_tokens, + num_beams=args.num_beams, + num_return_sequences=args.num_beams, + 
output_scores=True, + return_dict_in_generate=True, + early_stopping=True, + ) + output_ids = output["sequences"] + scores = output["sequences_scores"] + + output = tokenizer.batch_decode( + output_ids, skip_special_tokens=True + ) + # print(output) + topk_res = get_topk_results(output,scores,targets,args.num_beams, + all_items=all_items if args.filter_items else None) + + batch_metrics_res = get_metrics_results(topk_res, metrics) + # print(batch_metrics_res) + + for m, res in batch_metrics_res.items(): + if m not in metrics_results: + metrics_results[m] = res + else: + metrics_results[m] += res + + if (step+1)%10 == 0: + temp={} + for m in metrics_results: + temp[m] = metrics_results[m] / total + print(temp) + + for m in metrics_results: + metrics_results[m] = metrics_results[m] / total + + all_prompt_results.append(metrics_results) + print("======================================================") + print("Prompt {} results: ".format(prompt_id), metrics_results) + print("======================================================") + print("") + + mean_results = {} + min_results = {} + max_results = {} + + for m in metrics: + all_res = [_[m] for _ in all_prompt_results] + mean_results[m] = sum(all_res)/len(all_res) + min_results[m] = min(all_res) + max_results[m] = max(all_res) + + print("======================================================") + print("Mean results: ", mean_results) + print("Min results: ", min_results) + print("Max results: ", max_results) + print("======================================================") + + + save_data={} + save_data["test_prompt_ids"] = args.test_prompt_ids + save_data["mean_results"] = mean_results + save_data["min_results"] = min_results + save_data["max_results"] = max_results + save_data["all_prompt_results"] = all_prompt_results + + with open(args.results_file, "w") as f: + json.dump(save_data, f, indent=4) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LLMRec_test") + parser = 
parse_global_args(parser) + parser = parse_dataset_args(parser) + parser = parse_test_args(parser) + + args = parser.parse_args() + + test(args) diff --git a/test_ddp.py b/test_ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..3676ea9a53ce26209bae720aeaecc50ddb9e943e --- /dev/null +++ b/test_ddp.py @@ -0,0 +1,218 @@ +import argparse +import json +import os +import sys + +import torch +import transformers +import torch.distributed as dist +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel +from peft import PeftModel +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig + +from utils import * +from collator import TestCollator +from prompt import all_prompt +from evaluate import get_topk_results, get_metrics_results + + +def test_ddp(args): + + set_seed(args.seed) + world_size = int(os.environ.get("WORLD_SIZE", 1)) + local_rank = int(os.environ.get("LOCAL_RANK") or 0) + torch.cuda.set_device(local_rank) + if local_rank == 0: + print(vars(args)) + + dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank) + + device_map = {"": local_rank} + device = torch.device("cuda",local_rank) + + tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path) + if args.lora: + model = LlamaForCausalLM.from_pretrained( + args.base_model, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map=device_map, + ) + model.resize_token_embeddings(len(tokenizer)) + model = PeftModel.from_pretrained( + model, + args.ckpt_path, + torch_dtype=torch.bfloat16, + device_map=device_map, + ) + else: + model = LlamaForCausalLM.from_pretrained( + args.ckpt_path, + torch_dtype=torch.bfloat16, + low_cpu_mem_usage=True, + device_map=device_map, + ) + # assert model.config.vocab_size == len(tokenizer) + model = DistributedDataParallel(model, device_ids=[local_rank]) + + if args.test_prompt_ids == 
"all": + if args.test_task.lower() == "seqrec": + prompt_ids = range(len(all_prompt["seqrec"])) + elif args.test_task.lower() == "itemsearch": + prompt_ids = range(len(all_prompt["itemsearch"])) + elif args.test_task.lower() == "fusionseqrec": + prompt_ids = range(len(all_prompt["fusionseqrec"])) + else: + prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")] + + test_data = load_test_dataset(args) + ddp_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=local_rank, drop_last=True) + + test_data = load_test_dataset(args) + collator = TestCollator(args, tokenizer) + all_items = test_data.get_all_items() + + + prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer) + + + test_loader = DataLoader(test_data, batch_size=args.test_batch_size, collate_fn=collator, + sampler=ddp_sampler, num_workers=2, pin_memory=True) + + if local_rank == 0: + print("data num:", len(test_data)) + + model.eval() + + metrics = args.metrics.split(",") + all_prompt_results = [] + with torch.no_grad(): + + for prompt_id in prompt_ids: + + if local_rank == 0: + print("Start prompt: ",prompt_id) + + test_loader.dataset.set_prompt(prompt_id) + metrics_results = {} + total = 0 + + for step, batch in enumerate(tqdm(test_loader)): + inputs = batch[0].to(device) + targets = batch[1] + bs = len(targets) + num_beams = args.num_beams + while True: + try: + output = model.module.generate( + input_ids=inputs["input_ids"], + attention_mask=inputs["attention_mask"], + max_new_tokens=10, + prefix_allowed_tokens_fn=prefix_allowed_tokens, + num_beams=num_beams, + num_return_sequences=num_beams, + output_scores=True, + return_dict_in_generate=True, + early_stopping=True, + ) + break + except torch.cuda.OutOfMemoryError as e: + print("Out of memory!") + num_beams = num_beams -1 + print("Beam:", num_beams) + except Exception: + raise RuntimeError + + output_ids = output["sequences"] + scores = output["sequences_scores"] + + output = tokenizer.batch_decode( + 
output_ids, skip_special_tokens=True + ) + + topk_res = get_topk_results(output, scores, targets, num_beams, + all_items=all_items if args.filter_items else None) + + bs_gather_list = [None for _ in range(world_size)] + dist.all_gather_object(obj=bs, object_list=bs_gather_list) + total += sum(bs_gather_list) + res_gather_list = [None for _ in range(world_size)] + dist.all_gather_object(obj=topk_res, object_list=res_gather_list) + + + if local_rank == 0: + all_device_topk_res = [] + for ga_res in res_gather_list: + all_device_topk_res += ga_res + batch_metrics_res = get_metrics_results(all_device_topk_res, metrics) + for m, res in batch_metrics_res.items(): + if m not in metrics_results: + metrics_results[m] = res + else: + metrics_results[m] += res + + if (step + 1) % 50 == 0: + temp = {} + for m in metrics_results: + temp[m] = metrics_results[m] / total + print(temp) + + dist.barrier() + + if local_rank == 0: + for m in metrics_results: + metrics_results[m] = metrics_results[m] / total + + all_prompt_results.append(metrics_results) + print("======================================================") + print("Prompt {} results: ".format(prompt_id), metrics_results) + print("======================================================") + print("") + + dist.barrier() + + dist.barrier() + + if local_rank == 0: + mean_results = {} + min_results = {} + max_results = {} + + for m in metrics: + all_res = [_[m] for _ in all_prompt_results] + mean_results[m] = sum(all_res)/len(all_res) + min_results[m] = min(all_res) + max_results[m] = max(all_res) + + print("======================================================") + print("Mean results: ", mean_results) + print("Min results: ", min_results) + print("Max results: ", max_results) + print("======================================================") + + + save_data={} + save_data["test_prompt_ids"] = args.test_prompt_ids + save_data["mean_results"] = mean_results + save_data["min_results"] = min_results + save_data["max_results"] = 
max_results + save_data["all_prompt_results"] = all_prompt_results + + with open(args.results_file, "w") as f: + json.dump(save_data, f, indent=4) + print("Save file: ", args.results_file) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LLMRec_test") + parser = parse_global_args(parser) + parser = parse_dataset_args(parser) + parser = parse_test_args(parser) + + args = parser.parse_args() + + test_ddp(args) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..456f19d004611692e7299c82648d0992e10bc7f7 --- /dev/null +++ b/utils.py @@ -0,0 +1,452 @@ +import json +import logging +import os +import random +import datetime + +import numpy as np +import torch +from torch.utils.data import ConcatDataset +from data import * +# from data import SeqRecDataset, ItemFeatDataset, ItemSearchDataset, FusionSeqRecDataset, SeqRecTestDataset, PreferenceObtainDataset +from data_finetune import * +# from data_finetune import SeqRecFinetune, ItemFeatFinetune, ItemSearchFinetune, FusionSeqRecFinetune, PreferenceObtainFinetune + +def parse_evaluate_args(parser): + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--base_model", type=str, default="../llama-7b/", help="basic model path") + parser.add_argument("--output_dir", type=str, default="./ckpt/", help="The output directory") + + parser.add_argument("--data_path", type=str, default="", + help="data directory") + parser.add_argument("--tasks", type=str, + default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item', + help="Downstream tasks, separate by comma") + parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,0,0,0,0,0,0", + help="the number of sampling data for each task") + parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name") + parser.add_argument("--index_file", type=str, 
default=".index.item.json", help="the item indices file") + parser.add_argument("--user_index_file", type=str, default=".index.user.json", help="the user indices file") + parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers") + parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor") + + # arguments related to sequential task + parser.add_argument("--max_his_len", type=int, default=20, + help="the max number of items in history sequence, -1 means no limit") + parser.add_argument("--add_prefix", action="store_true", default=False, + help="whether add sequential prefix in history") + parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history") + parser.add_argument("--only_train_response", action="store_true", default=False, + help="whether only train on responses") + + parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1,1,1,1", + help="the number of sampling prompts for each task") + + parser.add_argument("--valid_prompt_id", type=int, default=0, + help="The prompt used for validation") + parser.add_argument("--sample_valid", action="store_true", default=True, + help="use sampled prompt for validation") + parser.add_argument("--valid_prompt_sample_num", type=int, default=2, + help="the number of sampling validation sequential recommendation prompts") + + parser.add_argument("--ckpt_path", type=str, default="", help="The checkpoint path") + parser.add_argument("--lora", action="store_true", default=False) + parser.add_argument("--filter_items", action="store_true", default=False, + help="whether filter illegal items") + + parser.add_argument("--results_file", type=str, default="./results/test-ddp.json", help="result output path") + + parser.add_argument("--test_batch_size", type=int, default=1) + parser.add_argument("--num_beams", type=int, default=20) + parser.add_argument("--sample_num", type=int, default=-1, 
+ help="test sample number, -1 represents using all test data") + parser.add_argument("--gpu_id", type=int, default=0, + help="GPU ID when testing with single GPU") + parser.add_argument("--test_prompt_ids", type=str, default="0", + help="test prompt ids, separate by comma. 'all' represents using all") + parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10", + help="test metrics, separate by comma") + parser.add_argument("--test_task", type=str, default="SeqRec") + + return parser + +def parse_finetune_args(parser): + parser.add_argument("--seed", type=int, default=42, help="Random seed") + parser.add_argument("--base_model", type=str, default="../llama-7b/", help="basic model path") + + parser.add_argument("--output_dir", type=str, default="./ckpt/", help="The output directory") + + parser.add_argument("--data_path", type=str, default="", + help="data directory") + parser.add_argument("--tasks", type=str, + default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item', + help="Downstream tasks, separate by comma") + parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,0,0,0,0,0,0", + help="the number of sampling data for each task") + parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name") + parser.add_argument("--index_file", type=str, default=".index.json", help="item indices file") + parser.add_argument("--user_index_file", type=str, default=".user-index.json", help="user indices file") + + parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers") + parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor") + + parser.add_argument("--max_his_len", type=int, default=20, + help="the max number of items in history sequence, -1 means no limit") + parser.add_argument("--add_prefix", action="store_true", default=False, + 
help="whether add sequential prefix in history") + parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history") + parser.add_argument("--only_train_response", action="store_true", default=False, + help="whether only train on responses") + + parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1,1,1,1", + help="the number of sampling prompts for each task") + + parser.add_argument("--valid_prompt_id", type=int, default=0, + help="The prompt used for validation") + parser.add_argument("--sample_valid", action="store_true", default=True, + help="use sampled prompt for validation") + parser.add_argument("--valid_prompt_sample_num", type=int, default=2, + help="the number of sampling validation sequential recommendation prompts") + + parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer') + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--per_device_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=2) + parser.add_argument("--logging_step", type=int, default=10) + parser.add_argument("--model_max_length", type=int, default=2048) + parser.add_argument("--weight_decay", type=float, default=0.01) + + parser.add_argument("--lora_r", type=int, default=8) + parser.add_argument("--lora_alpha", type=int, default=32) + parser.add_argument("--lora_dropout", type=float, default=0.05) + parser.add_argument("--lora_target_modules", type=str, + default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj", help="separate by comma") + parser.add_argument("--lora_modules_to_save", type=str, + default="embed_tokens,lm_head", help="separate by comma") + + parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="either training checkpoint or final adapter") + + parser.add_argument("--warmup_ratio", type=float, 
default=0.01) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--save_and_eval_strategy", type=str, default="epoch") + parser.add_argument("--save_and_eval_steps", type=int, default=1000) + parser.add_argument("--fp16", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json") + parser.add_argument("--remove_unused_columns", action="store_true", default=False, help='if remove unused columns') + + parser.add_argument("--reindex", type = int, default = 0) + # parser.add_argument("--user_reindex", type = int, default = 0) + parser.add_argument("--ckpt_path", type=str, default="") + + return parser + +def load_finetune_datasets(args): + + tasks = args.tasks.split(",") + + train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")] + assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number" + train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")] + assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number" + + train_datasets = [] + for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num): + if task.lower() == "seqrec": + dataset = SeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == "item2index" or task.lower() == "index2item": + dataset = ItemFeatFinetune(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == "fusionseqrec": + dataset = FusionSeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == "itemsearch": + dataset = ItemSearchFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) 
+ + elif task.lower() == "preferenceobtain": + dataset = PreferenceObtainFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == "usersearch": + dataset = UserSearchFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() in ["user2pref", "pref2user"]: + dataset = UserFeatFinetune(args, task = task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + else: + raise NotImplementedError + train_datasets.append(dataset) + + train_data = ConcatDataset(train_datasets) + + valid_data = SeqRecFinetune(args, "valid", args.valid_prompt_sample_num) + + return train_data, valid_data + +# def load_finetune_datasets(args): +# tasks = args.tasks.split(",") +# train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")] +# assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number" +# train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")] +# assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number" + +# train_datasets = [] +# for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num): +# if task.lower() == "seqrec": +# dataset = SeqRecFinetune(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + +# elif task.lower() == "item2index" or task.lower() == "index2item": +# dataset = ItemFeatFinetune(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + +# elif task.lower() in ["inters2title", "inters2description", "intertitles2item"]: +# dataset = FusionSeqRecFinetune(args, task=task.lower(), mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + +# elif task.lower() in ["itemsearch", "query2item"]: +# dataset = ItemSearchFinetune(args, task=task.lower(),mode="train", prompt_sample_num=prompt_sample_num, 
sample_num=data_sample_num) + +# elif task.lower() == "preferenceobtain": +# dataset = PreferenceObtainFinetune(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + +# else: +# raise NotImplementedError +# train_datasets.append(dataset) + +# train_data = ConcatDataset(train_datasets) + +# valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num) + +# return train_data, valid_data + +def parse_global_args(parser): + parser.add_argument("--seed", type=int, default=42, help="Random seed") + + parser.add_argument("--base_model", type=str, + default="../llama-7b/", + help="basic model path") + parser.add_argument("--output_dir", type=str, + default="./ckpt/", + help="The output directory") + return parser + +def parse_dataset_args(parser): + parser.add_argument("--data_path", type=str, default="", + help="data directory") + parser.add_argument("--tasks", type=str, + default='seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item', + help="Downstream tasks, separate by comma") + parser.add_argument("--train_data_sample_num", type=str, default="0,0,0,0,0,0,0,0,0", + help="the number of sampling data for each task") + parser.add_argument("--dataset", type=str, default="Instruments", help="Dataset name") + parser.add_argument("--index_file", type=str, default=".index.json", help="the item indices file") + parser.add_argument("--user_index_file", type=str, default=".user-index.json", help="user indices file") + parser.add_argument("--dataloader_num_workers", type=int, default=0, help="dataloader num_workers") + parser.add_argument("--dataloader_prefetch_factor", type=int, default=2, help="dataloader prefetch_factor") + + # arguments related to sequential task + parser.add_argument("--max_his_len", type=int, default=20, + help="the max number of items in history sequence, -1 means no limit") + parser.add_argument("--add_prefix", action="store_true", default=False, + help="whether add 
sequential prefix in history") + parser.add_argument("--his_sep", type=str, default=", ", help="The separator used for history") + parser.add_argument("--only_train_response", action="store_true", default=False, + help="whether only train on responses") + + parser.add_argument("--train_prompt_sample_num", type=str, default="1,1,1,1,1,1,1,1,1", + help="the number of sampling prompts for each task") + + parser.add_argument("--valid_prompt_id", type=int, default=0, + help="The prompt used for validation") + parser.add_argument("--sample_valid", action="store_true", default=True, + help="use sampled prompt for validation") + parser.add_argument("--valid_prompt_sample_num", type=int, default=2, + help="the number of sampling validation sequential recommendation prompts") + + return parser + +def parse_train_args(parser): + parser.add_argument("--optim", type=str, default="adamw_torch", help='The name of the optimizer') + parser.add_argument("--epochs", type=int, default=4) + parser.add_argument("--learning_rate", type=float, default=2e-5) + parser.add_argument("--per_device_batch_size", type=int, default=8) + parser.add_argument("--gradient_accumulation_steps", type=int, default=2) + parser.add_argument("--logging_step", type=int, default=10) + parser.add_argument("--model_max_length", type=int, default=2048) + parser.add_argument("--weight_decay", type=float, default=0.01) + + parser.add_argument("--lora_r", type=int, default=8) + parser.add_argument("--lora_alpha", type=int, default=32) + parser.add_argument("--lora_dropout", type=float, default=0.05) + parser.add_argument("--lora_target_modules", type=str, + default="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj", help="separate by comma") + parser.add_argument("--lora_modules_to_save", type=str, + default="embed_tokens,lm_head", help="separate by comma") + + parser.add_argument("--resume_from_checkpoint", type=str, default=None, help="either training checkpoint or final adapter") + + 
parser.add_argument("--warmup_ratio", type=float, default=0.01) + parser.add_argument("--lr_scheduler_type", type=str, default="cosine") + parser.add_argument("--save_and_eval_strategy", type=str, default="epoch") + parser.add_argument("--save_and_eval_steps", type=int, default=1000) + parser.add_argument("--fp16", action="store_true", default=False) + parser.add_argument("--bf16", action="store_true", default=False) + parser.add_argument("--deepspeed", type=str, default="./config/ds_z3_bf16.json") + parser.add_argument("--remove_unused_columns", action="store_true", default=False, help='if remove unused columns') + + return parser + +def parse_rqvae_args(parser): + parser.add_argument('--lr', type=float, default=1e-3, help='learning rate') + # parser.add_argument('--epochs', type=int, default=5000, help='number of epochs') + parser.add_argument('--batch_size', type=int, default=1024, help='batch size') + parser.add_argument('--num_workers', type=int, default=4, ) + parser.add_argument('--eval_step', type=int, default=50, help='eval step') + parser.add_argument('--learner', type=str, default="AdamW", help='optimizer') + # parser.add_argument("--data_path", type=str, + # default="../data/Games/Games.emb-llama-td.npy", + # help="Input data path.") + + # parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight') + parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio") + parser.add_argument("--bn", type=bool, default=False, help="use bn or not") + parser.add_argument("--loss_type", type=str, default="mse", help="loss_type") + parser.add_argument("--kmeans_init", type=bool, default=False, help="use kmeans_init or not") + parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters") + parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0, 0.0], help="sinkhorn epsilons") + parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn 
iters") + + parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu") + + parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256,256,256,256], help='emb num of every vq') + parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size') + parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantization loss weight') + parser.add_argument('--layers', type=int, nargs='+', default=[2048,1024,512,256,128,64], help='hidden sizes of every layer') + + parser.add_argument("--ckpt_path", type=str, default="", help="output directory for model") + parser.add_argument("--warmup", type=int, default=5, help="epochs for warmup") + parser.add_argument("--item_model", type=str, default="", help="") + parser.add_argument("--user_model", type=str, default="", help="") + + return parser + +def parse_test_args(parser): + + parser.add_argument("--ckpt_path", type=str, + default="", + help="The checkpoint path") + parser.add_argument("--lora", action="store_true", default=False) + parser.add_argument("--filter_items", action="store_true", default=False, + help="whether filter illegal items") + + parser.add_argument("--results_file", type=str, + default="./results/test-ddp.json", + help="result output path") + + parser.add_argument("--test_batch_size", type=int, default=1) + parser.add_argument("--num_beams", type=int, default=20) + parser.add_argument("--sample_num", type=int, default=-1, + help="test sample number, -1 represents using all test data") + parser.add_argument("--gpu_id", type=int, default=0, + help="GPU ID when testing with single GPU") + parser.add_argument("--test_prompt_ids", type=str, default="0", + help="test prompt ids, separate by comma. 
'all' represents using all") + parser.add_argument("--metrics", type=str, default="hit@1,hit@5,hit@10,ndcg@5,ndcg@10", + help="test metrics, separate by comma") + parser.add_argument("--test_task", type=str, default="SeqRec") + + + return parser + +def get_local_time(): + cur = datetime.datetime.now() + cur = cur.strftime("%b-%d-%Y_%H-%M-%S") + + return cur + + +def set_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enabled = False + +def ensure_dir(dir_path): + + os.makedirs(dir_path, exist_ok=True) + + +def load_datasets(args): + + tasks = args.tasks.split(",") + + train_prompt_sample_num = [int(_) for _ in args.train_prompt_sample_num.split(",")] + assert len(tasks) == len(train_prompt_sample_num), "prompt sample number does not match task number" + train_data_sample_num = [int(_) for _ in args.train_data_sample_num.split(",")] + assert len(tasks) == len(train_data_sample_num), "data sample number does not match task number" + + train_datasets = [] + for task, prompt_sample_num,data_sample_num in zip(tasks,train_prompt_sample_num,train_data_sample_num): + if task.lower() == "seqrec": + dataset = SeqRecDataset(args, mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == "item2index" or task.lower() == "index2item": + dataset = ItemFeatDataset(args, task=task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() in ["inters2title", "inters2description", "intertitles2item"]: + dataset = FusionSeqRecDataset(args, task=task.lower(), mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() in ["itemsearch", "query2item"]: + dataset = ItemSearchDataset(args, task=task.lower(),mode="train", prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif 
task.lower() == "preferenceobtain": + dataset = PreferenceObtainDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() == 'usersearch': + dataset = UserSearchDataset(args, prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + elif task.lower() in ["pref2user", "user2pref"]: + dataset = UserFeatDataset(args, task = task.lower(), prompt_sample_num=prompt_sample_num, sample_num=data_sample_num) + + else: + raise NotImplementedError + train_datasets.append(dataset) + + train_data = ConcatDataset(train_datasets) + + valid_data = SeqRecDataset(args,"valid",args.valid_prompt_sample_num) + + return train_data, valid_data + +def load_test_dataset(args): + + if args.test_task.lower() == "seqrec": + test_data = SeqRecFinetune(args, mode="test", sample_num=args.sample_num) + elif args.test_task.lower() == "itemsearch": + test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num) + elif args.test_task.lower() == "fusionseqrec": + test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num) + else: + raise NotImplementedError + + return test_data + +# def load_test_dataset(args): + +# if args.test_task.lower() == "seqrec": +# test_data = SeqRecDataset(args, mode="test", sample_num=args.sample_num) +# elif args.test_task.lower() == "itemsearch": +# test_data = ItemSearchDataset(args, mode="test", sample_num=args.sample_num) +# elif args.test_task.lower() == "fusionseqrec": +# test_data = FusionSeqRecDataset(args, mode="test", sample_num=args.sample_num) +# else: +# raise NotImplementedError + +# return test_data + +def load_json(file): + with open(file, 'r') as f: + data = json.load(f) + return data