Junyin committed (verified)
Commit 936459a · Parent(s): 0ee60a9

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+asset/model.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,124 @@
# LC-Rec

This is the official PyTorch implementation for the paper:

> [Adapting Large Language Models by Integrating Collaborative Semantics for Recommendation](https://arxiv.org/abs/2311.09049)

## Overview

We propose **LC-Rec**, a new approach that integrates **L**anguage and **C**ollaborative semantics to improve LLMs for **Rec**ommender systems. To bridge the large gap between the language semantics modeled by LLMs and the collaborative semantics implied by recommender systems, we make major contributions in two aspects. For item indexing, we design a learning-based vector quantization method with uniform semantic mapping, which assigns meaningful and non-conflicting IDs (called item indices) to items. For alignment tuning, we propose a series of specially designed tuning tasks that enhance the integration of collaborative semantics into LLMs. These fine-tuning tasks force the LLM to deeply integrate language and collaborative semantics (characterized by the learned item indices), achieving an effective adaptation to recommender systems.

![model](./asset/model.jpg)
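The learned item indices act like multi-level residual-quantization codes: each level encodes the residual left over by the previous one, so many items can share coarse codes while still receiving non-conflicting full IDs. A minimal sketch of that index assignment (illustrative codebooks and vectors, not the paper's trained model):

```python
# Minimal sketch of residual vector quantization for item indexing.
# Codebooks and vectors below are hypothetical, not the trained LC-Rec ones.

def nearest(codebook, vec):
    """Index of the codebook entry closest to vec (squared L2 distance)."""
    return min(range(len(codebook)),
               key=lambda i: sum((c - v) ** 2 for c, v in zip(codebook[i], vec)))

def rq_indices(codebooks, vec):
    """Quantize vec level by level; each level encodes the residual."""
    residual = list(vec)
    indices = []
    for book in codebooks:
        idx = nearest(book, residual)
        indices.append(idx)
        residual = [r - c for r, c in zip(residual, book[idx])]
    return indices

# Two levels, tiny 2-D codebooks.
books = [
    [[0.0, 0.0], [1.0, 1.0]],   # level 1: coarse codes
    [[0.0, 0.1], [0.1, 0.0]],   # level 2: refinement of the residual
]
print(rq_indices(books, [0.9, 1.05]))  # [1, 0]
```

The multi-level code `[1, 0]` is what would be rendered into index tokens for the LLM's vocabulary.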
## Requirements

```
torch==1.13.1+cu117
accelerate
bitsandbytes
deepspeed
evaluate
peft
sentencepiece
tqdm
transformers
```
## Model Checkpoint

The delta weights for the three datasets can be downloaded from the Hugging Face Hub ([Instruments](https://huggingface.co/bwzheng0324/lc-rec-instruments-delta), [Arts](https://huggingface.co/bwzheng0324/lc-rec-arts-delta), [Games](https://huggingface.co/bwzheng0324/lc-rec-games-delta)). After downloading, add our deltas to the original LLaMA weights to obtain the LC-Rec weights:

1. Get the original [LLaMA](https://huggingface.co/huggyllama/llama-7b) weights.
2. Run the following script to obtain the LC-Rec weights by applying our delta:

```shell
python convert/merge_delta.py \
    --base-model-path /path/to/llama-7b \
    --target-model-path /path/output/lc-rec \
    --delta-path bwzheng0324/lc-rec-games-delta
```
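Conceptually, applying a delta is an element-wise addition of delta weights onto the base weights, parameter by parameter. A minimal sketch with plain dicts and lists standing in for state dicts (not the repo's `merge_delta.py`):

```python
# Sketch: merge delta weights into base weights, tensor by tensor.
# Plain lists stand in for real model tensors here.

def merge_delta(base_state, delta_state):
    merged = {}
    for name, delta in delta_state.items():
        base = base_state.get(name)
        if base is not None and len(base) == len(delta):
            merged[name] = [b + d for b, d in zip(base, delta)]
        else:
            # e.g. embeddings resized for new index tokens: keep the delta as-is
            merged[name] = list(delta)
    return merged

base = {"w": [1.0, 2.0]}
delta = {"w": [0.5, -0.5], "new_emb": [3.0]}
print(merge_delta(base, delta))  # {'w': [1.5, 1.5], 'new_emb': [3.0]}
```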
## Dataset

We use three datasets in our paper, all of which have been uploaded to [Google Drive](https://drive.google.com/drive/folders/1RcJ2M1l5zWPHYuGd9l5Gibcs5w5aI3y6?usp=sharing).
## Train

The detailed scripts for all three datasets are in `run.sh`:

```shell
DATASET=Games
BASE_MODEL=huggyllama/llama-7b
DATA_PATH=./data
OUTPUT_DIR=./ckpt/$DATASET/

torchrun --nproc_per_node=8 --master_port=3324 finetune.py \
    --base_model $BASE_MODEL \
    --output_dir $OUTPUT_DIR \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --per_device_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --learning_rate 5e-5 \
    --epochs 4 \
    --weight_decay 0.01 \
    --save_and_eval_strategy epoch \
    --deepspeed ./config/ds_z3_bf16.json \
    --bf16 \
    --only_train_response \
    --tasks seqrec,item2index,index2item,fusionseqrec,itemsearch,preferenceobtain \
    --train_prompt_sample_num 1,1,1,1,1,1 \
    --train_data_sample_num 0,0,0,100000,0,0 \
    --index_file .index.json


cd convert
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
cd ..
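With these settings the effective (global) batch size is the product of the number of `torchrun` processes, the per-device batch size, and the gradient-accumulation steps:

```python
# Effective batch size for the training command above.
nproc_per_node = 8               # torchrun processes (one per GPU)
per_device_batch_size = 8
gradient_accumulation_steps = 2

global_batch_size = nproc_per_node * per_device_batch_size * gradient_accumulation_steps
print(global_batch_size)  # 128
```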
## Test

Test with a single GPU:

```shell
DATASET=Games
DATA_PATH=./data
CKPT_PATH=./ckpt/$DATASET/
RESULTS_FILE=./results/$DATASET/result.json

python test.py \
    --gpu_id 0 \
    --ckpt_path $CKPT_PATH \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --results_file $RESULTS_FILE \
    --test_batch_size 1 \
    --num_beams 20 \
    --test_prompt_ids all \
    --index_file .index.json
```

Test with multiple GPUs:

```shell
DATASET=Games
DATA_PATH=./data
CKPT_PATH=./ckpt/$DATASET/
RESULTS_FILE=./results/$DATASET/result.json

torchrun --nproc_per_node=8 --master_port=4324 test_ddp.py \
    --ckpt_path $CKPT_PATH \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --results_file $RESULTS_FILE \
    --test_batch_size 1 \
    --num_beams 20 \
    --test_prompt_ids all \
    --index_file .index.json
```
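With `--num_beams 20`, each prompt yields a ranked list of candidate item indices (best beam first). Ranking metrics such as hit rate at K can then be computed by checking whether the ground-truth item appears among the top K beams. A sketch of that metric (not the repo's `test.py`; item indices are hypothetical strings):

```python
# Sketch: hit rate @ K over beam-ranked predictions (best beam first).
def hit_rate_at_k(beam_lists, targets, k):
    hits = sum(1 for beams, t in zip(beam_lists, targets) if t in beams[:k])
    return hits / len(targets)

preds = [["a", "b", "c"], ["x", "y", "z"]]   # top beams per test prompt
gold = ["b", "z"]
print(hit_rate_at_k(preds, gold, 2))  # 0.5  ("b" is in the top 2, "z" is not)
```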
## Acknowledgement

The implementation is based on [HuggingFace](https://github.com/huggingface/transformers).
asset/model.jpg ADDED

Git LFS Details

  • SHA256: 52223d0ef7f3701a6e40db9997e78c0a7f0d6bfce7965b9f27637e0e25fd1097
  • Pointer size: 132 Bytes
  • Size of remote file: 1.13 MB
collator.py ADDED
@@ -0,0 +1,273 @@
import torch
import copy
import argparse
from dataclasses import dataclass

import transformers
import math
from torch.utils.data import Sampler
import torch.distributed as dist
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, T5Tokenizer, T5Config, T5ForConditionalGeneration


class VanillaCollator(object):

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer

    def __call__(self, data):
        # print('collator data:', data)
        '''
        Example batch:
        [{
            'input_ids':
                "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
                ### Instruction:\n
                Access the user's historical item interaction records: {inters}.
                Your objective is to describe the next potential item for him, taking into account his past interactions.\n\n
                ### Response:",
            'labels':
                "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
                ### Instruction:\n
                Access the user's historical item interaction records: {inters}.
                Your objective is to describe the next potential item for him, taking into account his past interactions.\n\n
                ### Response:
                Dunlop guitar picks are a top choice of today's pro musician! Dunlop's wide variety of gauges, shapes, sizes and materials
                allows the player to select the exact pick for his/her own particular style of playing. From classic country to nu-metal,
                every great player knows that their pick is an integral part of their tone, and Dunlop guitar picks are the picks that more
                pros rely on in the studio or on stage. Picks are a grossly underrated accessory. Don't sacrifice your tone...pick Dunlop guitar picks!.",
            'inters': '341,2804,3895,3893,7064',
            'item': 'placeholder',
            'task': 'inters2description'
        },
        {
            'input_ids':
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
                ### Instruction:\n
                Based on the user\'s historical interactions with the following items: {inters}.
                You can infer his preference by observing the historical interactions: "The user\'s short-term preferences have shift to heavier picks,
                suggesting that He is looking for a heavier sound.". Now the user wants a new item and searches for: "I like the durability and
                effectiveness of the picks.". Please select a suitable item that matches his preference and search intent.\n\n
                ### Response:',
            'labels':
                'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n
                ### Instruction:\n
                Based on the user\'s historical interactions with the following items: {inters}.
                You can infer his preference by observing the historical interactions: "The user\'s short-term preferences have shift to heavier picks,
                suggesting that He is looking for a heavier sound.". Now the user wants a new item and searches for: "I like the durability and
                effectiveness of the picks.". Please select a suitable item that matches his preference and search intent.\n\n
                ### Response:{item}',
            'inters': '122,469,8918',
            'item': '7140',
            'task': 'itemsearch'
        }]
        '''
        dict_data = {
            'input_ids': [],
            'labels': [],
            'inters': [],
            'item': [],
            'users': [],
            'user': [],
            'task': []
        }

        for d in data:
            for k in dict_data.keys():
                if k == 'labels':
                    dict_data[k].append(d[k] + self.tokenizer.eos_token)
                else:
                    dict_data[k].append(d[k])

        return dict_data
class Collator(object):

    def __init__(self, args, tokenizer):
        self.args = args
        self.only_train_response = args.only_train_response
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id
        # print(self.tokenizer.model_max_length)

    def __call__(self, batch):

        input_texts = [d["input_ids"] for d in batch]
        full_texts = [d["labels"] + self.tokenizer.eos_token for d in batch]

        inputs = self.tokenizer(
            text=full_texts,
            text_target=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )
        labels = copy.deepcopy(inputs["input_ids"])
        if self.only_train_response:
            # ignore padding
            labels[labels == self.tokenizer.pad_token_id] = -100
            # ignore input text
            labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100

        inputs["labels"] = labels

        return inputs
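With `only_train_response`, every prompt token and padding token is labeled `-100` so the cross-entropy loss is computed only on the response span. A minimal sketch of that masking with plain token-ID lists in place of tensors:

```python
IGNORE_INDEX = -100  # positions with this label are skipped by the loss

def mask_prompt_labels(full_ids, prompt_len, pad_id):
    """Labels for causal-LM training: copy the sequence, then hide
    the prompt tokens and padding from the loss."""
    labels = list(full_ids)
    for i, tok in enumerate(labels):
        if i < prompt_len or tok == pad_id:
            labels[i] = IGNORE_INDEX
    return labels

# prompt = [5, 6, 7], response = [8, 9], one pad token (0) at the end
print(mask_prompt_labels([5, 6, 7, 8, 9, 0], prompt_len=3, pad_id=0))
# [-100, -100, -100, 8, 9, -100]
```

The `Collator` above reaches the same result indirectly: it tokenizes the prompt-only text as `text_target` and masks every position that the prompt covers.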
class TestCollator(object):

    def __init__(self, args, tokenizer):
        self.args = args
        self.tokenizer = tokenizer
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0

        if isinstance(self.tokenizer, LlamaTokenizer):
            self.tokenizer.padding_side = "left"

    def __call__(self, batch):
        input_texts = [d["input_ids"] for d in batch]
        targets = [d["labels"] for d in batch]
        inputs = self.tokenizer(
            text=input_texts,
            return_tensors="pt",
            padding="longest",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_attention_mask=True,
        )

        return (inputs, targets)
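The `padding_side = "left"` setting matters at test time: a decoder-only model generates from the last position, so pads must sit on the left to keep every prompt flush against the generation start. A sketch of the difference with plain ID lists:

```python
def pad_batch(seqs, pad_id=0, side="left"):
    """Pad variable-length ID lists to a common length on one side."""
    width = max(len(s) for s in seqs)
    if side == "left":
        return [[pad_id] * (width - len(s)) + s for s in seqs]
    return [s + [pad_id] * (width - len(s)) for s in seqs]

batch = [[7, 8], [1, 2, 3]]
print(pad_batch(batch, side="left"))   # [[0, 7, 8], [1, 2, 3]]
print(pad_batch(batch, side="right"))  # [[7, 8, 0], [1, 2, 3]]
```

With right padding, generation for the first sequence would start after a pad token; left padding avoids that.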
# RuntimeError: Cannot re-initialize CUDA in forked subprocess.
# To use CUDA with multiprocessing, you must use the 'spawn' start method.
# class ValidCollator(object):
#     def __init__(self, args, model):
#         self.args = args
#         self.model = model
#         self.only_train_response = args.only_train_response
#         self.tokenizer = model.tokenizer
#     def __call__(self, data):
#         llama_model = self.model.model.get_decoder()
#         for d in data:
#             inter_emb_list = []
#             inter_item_list = d['inters'].split(',')
#             for inter_item in inter_item_list:
#                 inter_feature = self.model.item_texts[inter_item]['title'] + ' ' + self.model.item_texts[inter_item]['description']
#                 inter_id = self.tokenizer(inter_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
#                 inter_emb = llama_model(input_ids=inter_id.input_ids, attention_mask=inter_id.attention_mask)
#                 inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
#                 inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim=-1, keepdim=True)
#                 inter_emb_list.append(inter_emb.detach())
#             inter_embs = torch.cat(inter_emb_list, dim=0)
#             item_feature = self.model.item_texts[d['item']]['title'] + ' ' + self.model.item_texts[d['item']]['description']
#             item_ids = self.tokenizer(item_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
#             item_emb = llama_model(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
#             item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
#             item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim=-1, keepdim=True)
#             item_emb = item_emb.detach()
#
#             rqids = self.model.rqvae.get_indices(torch.cat([inter_embs, item_emb], dim=0))
#
#             inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1]
#             item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1]
#
#             text_rqids = {}
#             code = ''
#             for rqid in inters_rqids:
#                 for k, idx in enumerate(rqid):
#                     code = code + self.model.prefix[k].format(idx)
#                 code = code + ', '
#             text_rqids['inters'] = code[:-2]
#             code = ''
#             for k, idx in enumerate(item_rqid):
#                 code = code + self.model.prefix[k].format(idx)
#             text_rqids['item'] = code
#
#             d['input_ids'] = d['input_ids'].format(inters=text_rqids['inters'])
#             d['labels'] = d['labels'].format(inters=text_rqids['inters'], item=text_rqids['item'])
#
#         input_texts = [d["input_ids"] for d in data]
#         full_texts = [d["labels"] + self.tokenizer.eos_token for d in data]
#
#         inputs = self.tokenizer(
#             text=full_texts,
#             text_target=input_texts,
#             return_tensors="pt",
#             padding="longest",
#             max_length=self.tokenizer.model_max_length,
#             truncation=True,
#             return_attention_mask=True,
#         )
#
#         labels = copy.deepcopy(inputs["input_ids"])
#         if self.only_train_response:
#             labels[labels == self.tokenizer.pad_token_id] = -100
#             labels[torch.where(inputs["labels"] != self.tokenizer.pad_token_id)] = -100
#         inputs["labels"] = labels
#
#         return inputs

# RuntimeError: Cannot re-initialize CUDA in forked subprocess.
# To use CUDA with multiprocessing, you must use the 'spawn' start method.
# class TestCollator(object):
#     def __init__(self, args, model):
#         self.args = args
#         self.model = model
#         self.tokenizer = model.tokenizer
#         if self.tokenizer.pad_token_id is None:
#             self.tokenizer.pad_token_id = 0
#         if isinstance(self.tokenizer, LlamaTokenizer):
#             self.tokenizer.padding_side = "left"
#
#     def __call__(self, data):
#         llama_model = self.model.model.get_decoder()
#         for d in data:
#             inter_emb_list = []
#             inter_item_list = d['inters'].split(',')
#             for inter_item in inter_item_list:
#                 inter_feature = self.model.item_texts[inter_item]['title'] + ' ' + self.model.item_texts[inter_item]['description']
#                 inter_id = self.tokenizer(inter_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
#                 inter_emb = llama_model(input_ids=inter_id.input_ids, attention_mask=inter_id.attention_mask)
#                 inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
#                 inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim=-1, keepdim=True)
#                 inter_emb_list.append(inter_emb.detach())
#             inter_embs = torch.cat(inter_emb_list, dim=0)
#             item_feature = self.model.item_texts[d['item']]['title'] + ' ' + self.model.item_texts[d['item']]['description']
#             item_ids = self.tokenizer(item_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
#             item_emb = llama_model(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
#             item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
#             item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim=-1, keepdim=True)
#             item_emb = item_emb.detach()
#
#             rqids = self.model.rqvae.get_indices(torch.cat([inter_embs, item_emb], dim=0))
#
#             inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1]
#             item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1]
#
#             text_rqids = {}
#             code = ''
#             for rqid in inters_rqids:
#                 for k, idx in enumerate(rqid):
#                     code = code + self.model.prefix[k].format(idx)
#                 code = code + ', '
#             text_rqids['inters'] = code[:-2]
#             code = ''
#             for k, idx in enumerate(item_rqid):
#                 code = code + self.model.prefix[k].format(idx)
#             text_rqids['item'] = code
#
#             d['input_ids'] = d['input_ids'].format(inters=text_rqids['inters'])
#             d['labels'] = d['labels'].format(inters=text_rqids['inters'], item=text_rqids['item'])
#
#         input_texts = [d["input_ids"] for d in data]
#         targets = [d["labels"] for d in data]
#
#         inputs = self.tokenizer(
#             text=input_texts,
#             return_tensors="pt",
#             padding="longest",
#             max_length=self.tokenizer.model_max_length,
#             truncation=True,
#             return_attention_mask=True,
#         )
#
#         return (inputs, targets)
config/ds_z2_bf16.json ADDED
@@ -0,0 +1,28 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
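Several fields in these configs are set to the string `"auto"`, which defers the value to the Hugging Face Trainer at launch (it fills them in from its own arguments, e.g. `train_micro_batch_size_per_gpu` from `per_device_train_batch_size`). A quick way to list which fields a config defers, on a trimmed-down fragment of the config above:

```python
import json

# Trimmed-down fragment of the DeepSpeed config above.
cfg_text = '''
{
  "bf16": {"enabled": "auto"},
  "zero_optimization": {"stage": 2, "reduce_bucket_size": 5e8},
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}
'''

cfg = json.loads(cfg_text)

def auto_fields(node, path=""):
    """Recursively collect dotted paths whose value is the string "auto"."""
    found = []
    for key, val in node.items():
        where = f"{path}.{key}" if path else key
        if val == "auto":
            found.append(where)
        elif isinstance(val, dict):
            found.extend(auto_fields(val, where))
    return found

print(auto_fields(cfg))
# ['bf16.enabled', 'train_batch_size', 'train_micro_batch_size_per_gpu']
```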
config/ds_z2_fp16.json ADDED
@@ -0,0 +1,34 @@
{
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 2,
        "allgather_partitions": true,
        "allgather_bucket_size": 5e8,
        "overlap_comm": true,
        "reduce_scatter": true,
        "reduce_bucket_size": 5e8,
        "contiguous_gradients": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
config/ds_z3_bf16.json ADDED
@@ -0,0 +1,31 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
config/ds_z3_bf16_save16bit.json ADDED
@@ -0,0 +1,31 @@
{
    "bf16": {
        "enabled": "auto"
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
config/ds_z3_fp16.json ADDED
@@ -0,0 +1,37 @@
{
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": false
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
config/ds_z3_fp16_save16bit.json ADDED
@@ -0,0 +1,37 @@
{
    "fp16": {
        "enabled": "auto",
        "auto_cast": false,
        "loss_scale": 0,
        "initial_scale_power": 16,
        "loss_scale_window": 1000,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    },
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "steps_per_print": 2000,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "wall_clock_breakdown": false,
    "flops_profiler": {
        "enabled": true,
        "profile_step": 10,
        "module_depth": -1,
        "top_modules": 3,
        "detailed": true,
        "output_file": "flops_profiler.out"
    }
}
continue_pretrain.py ADDED
@@ -0,0 +1,125 @@
import os
import sys
from typing import List
import argparse

import wandb
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

from collator import VanillaCollator
from rq_llama import *
from utils import *

parser = argparse.ArgumentParser(description='rqllama-pretrain')
parser = parse_global_args(parser)
parser = parse_train_args(parser)
parser = parse_dataset_args(parser)
parser = parse_rqvae_args(parser)
args = parser.parse_args()
wandb.init(config=args, reinit=True)

set_seed(args.seed)
ensure_dir(args.output_dir)

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
if local_rank == 0:
    print(vars(args))
if ddp:
    device_map = {"": local_rank}

train_data, valid_data = load_datasets(args)

rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
for i in range(len(args.num_emb_list)):
    rqllama.item_rqvae.rq.vq_layers[i].initted = True
    rqllama.user_rqvae.rq.vq_layers[i].initted = True

if local_rank == 0:
    print("token num:", len(rqllama.tokenizer))
    print("data num:", len(train_data))
    rqllama.tokenizer.save_pretrained(args.output_dir)
    rqllama.config.save_pretrained(args.output_dir)

if args.resume_from_checkpoint:
    checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
    args.resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        if local_rank == 0:
            print(f"Restarting from {checkpoint_name}")
        adapters_weights = torch.load(checkpoint_name)
        rqllama.model = set_peft_model_state_dict(rqllama.model, adapters_weights)
    else:
        if local_rank == 0:
            print(f"Checkpoint {checkpoint_name} not found")

if local_rank == 0:
    rqllama.model.print_trainable_parameters()

if not ddp and torch.cuda.device_count() > 1:
    rqllama.is_parallelizable = True
    rqllama.model_parallel = True

collator = VanillaCollator(args, rqllama.tokenizer)

trainer = transformers.Trainer(
    model=rqllama,
    train_dataset=train_data,
    eval_dataset=valid_data,
    args=transformers.TrainingArguments(
        seed=args.seed,
        per_device_train_batch_size=args.per_device_batch_size,
        per_device_eval_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        fp16=args.fp16,
        bf16=args.bf16,
        logging_steps=args.logging_step,
        optim=args.optim,
        gradient_checkpointing=True,
        evaluation_strategy=args.save_and_eval_strategy,
        save_strategy=args.save_and_eval_strategy,
        eval_steps=args.save_and_eval_steps,
        save_steps=args.save_and_eval_steps,
        output_dir=args.output_dir,
        save_total_limit=5,
        load_best_model_at_end=True,
        deepspeed=args.deepspeed,
        ddp_find_unused_parameters=False if ddp else None,
        report_to=None,
        eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        remove_unused_columns=args.remove_unused_columns,
    ),
    tokenizer=rqllama.tokenizer,
    data_collator=collator,
)
rqllama.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
    rqllama = torch.compile(rqllama)

trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

trainer.save_state()
trainer.save_model(output_dir=args.output_dir)

if local_rank == 0:
    print('rqllama pre-train finished.')
convert/convert.py ADDED
@@ -0,0 +1,16 @@
import transformers
import argparse
import os

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--source", "-s", type=str, default="", help="source path of models")
    parser.add_argument("--target", "-t", type=str, default="", help="target path of models")

    args, _ = parser.parse_known_args()

    assert os.path.exists(args.source)
    assert args.target != ""

    model = transformers.AutoModelForCausalLM.from_pretrained(args.source)
    model.save_pretrained(args.target, state_dict=model.state_dict())
convert/convert.sh ADDED
@@ -0,0 +1,18 @@
model=$1

set -x

for step in `ls ${model} | grep checkpoint | awk -F'-' '{ print $2 }'`
do
    mkdir ${model}/tmp-checkpoint-${step}
    mkdir ${model}/final-checkpoint-${step}
    python ./zero_to_fp32.py ${model}/checkpoint-${step}/ ${model}/tmp-checkpoint-${step}/pytorch_model.bin
    cp ${model}/*.json ${model}/tmp-checkpoint-${step}
    python ./convert.py -s ${model}/tmp-checkpoint-${step} -t ${model}/final-checkpoint-${step}
    cp ${model}/checkpoint-${step}/*.json ${model}/final-checkpoint-${step}
    cp ${model}/*.json ${model}/final-checkpoint-${step}
    cp ${model}/tokenizer* ${model}/final-checkpoint-${step}
    cp ${model}/train* ${model}/final-checkpoint-${step}
    #rm -rf ${model}/tmp-checkpoint-${step} ${model}/checkpoint-${step} ${model}/global_step${step}
    #mv ${model}/final-checkpoint-${step} ${model}/checkpoint-${step}
done
convert/convert_fp16.py ADDED
@@ -0,0 +1,23 @@
import argparse

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch


def convert_fp16(in_checkpoint, out_checkpoint):
    tokenizer = AutoTokenizer.from_pretrained(in_checkpoint, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        in_checkpoint, torch_dtype=torch.float16, low_cpu_mem_usage=True
    )
    model.save_pretrained(out_checkpoint)
    tokenizer.save_pretrained(out_checkpoint)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--in-checkpoint", type=str, help="Path to the model")
    parser.add_argument("--out-checkpoint", type=str, help="Path to the output model")
    args = parser.parse_args()

    convert_fp16(args.in_checkpoint, args.out_checkpoint)
convert/make_delta.py ADDED
@@ -0,0 +1,46 @@
+
+ import argparse
+
+ import torch
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+ def make_delta(base_model_path, target_model_path, delta_path):
+     print(f"Loading the base model from {base_model_path}")
+     base = AutoModelForCausalLM.from_pretrained(
+         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     )
+
+     print(f"Loading the target model from {target_model_path}")
+     target = AutoModelForCausalLM.from_pretrained(
+         target_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     )
+     target_tokenizer = AutoTokenizer.from_pretrained(target_model_path, use_fast=False)
+
+     print("Calculating the delta")
+     for name, param in tqdm(target.state_dict().items(), desc="Calculating delta"):
+         assert name in base.state_dict()
+         if param.shape == base.state_dict()[name].shape:
+             param.data -= base.state_dict()[name]
+         else:
+             print(name)
+
+     print(f"Saving the delta to {delta_path}")
+     if args.hub_repo_id:
+         kwargs = {"push_to_hub": True, "repo_id": args.hub_repo_id}
+     else:
+         kwargs = {}
+     target.save_pretrained(delta_path, **kwargs)
+     target_tokenizer.save_pretrained(delta_path, **kwargs)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base-model-path", type=str, required=True)
+     parser.add_argument("--target-model-path", type=str, required=True)
+     parser.add_argument("--delta-path", type=str, required=True)
+     parser.add_argument("--hub-repo-id", type=str)
+     args = parser.parse_args()
+
+     make_delta(args.base_model_path, args.target_model_path, args.delta_path)
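The core of `make_delta.py` is elementwise subtraction per tensor (`delta = target - base`), which `merge_delta.py` later inverts. A toy sketch of that arithmetic, with plain Python lists standing in for state-dict tensors:

```python
def make_delta(base, target):
    # delta = target - base, elementwise (make_delta.py does this per tensor)
    return [t - b for b, t in zip(base, target)]


def apply_delta(base, delta):
    # merge_delta.py inverts the operation: target = base + delta
    return [b + d for b, d in zip(base, delta)]


base = [1.0, 2.0, 3.0]      # base model weights (illustrative values)
target = [1.5, 1.0, 3.25]   # fine-tuned weights
delta = make_delta(base, target)
merged = apply_delta(base, delta)  # recovers the target weights exactly
print(delta, merged)
```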
convert/merge_delta.py ADDED
@@ -0,0 +1,167 @@
+
+ import argparse
+ import gc
+ import glob
+ import json
+ import os
+ import shutil
+ import tempfile
+
+ from huggingface_hub import snapshot_download
+ import torch
+ from torch import nn
+ from tqdm import tqdm
+ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
+
+
+ GB = 1 << 30
+
+
+ def split_files(model_path, tmp_path, split_size):
+     if not os.path.exists(model_path):
+         model_path = snapshot_download(repo_id=model_path)
+     if not os.path.exists(tmp_path):
+         os.makedirs(tmp_path)
+
+     file_pattern = os.path.join(model_path, "pytorch_model-*.bin")
+     files = glob.glob(file_pattern)
+
+     part = 0
+     try:
+         for file_path in tqdm(files):
+             state_dict = torch.load(file_path)
+             new_state_dict = {}
+
+             current_size = 0
+             for name, param in state_dict.items():
+                 param_size = param.numel() * param.element_size()
+
+                 if current_size + param_size > split_size:
+                     new_file_name = f"pytorch_model-{part}.bin"
+                     new_file_path = os.path.join(tmp_path, new_file_name)
+                     torch.save(new_state_dict, new_file_path)
+                     current_size = 0
+                     new_state_dict = None
+                     gc.collect()
+                     new_state_dict = {}
+                     part += 1
+
+                 new_state_dict[name] = param
+                 current_size += param_size
+
+             new_file_name = f"pytorch_model-{part}.bin"
+             new_file_path = os.path.join(tmp_path, new_file_name)
+             torch.save(new_state_dict, new_file_path)
+             new_state_dict = None
+             gc.collect()
+             new_state_dict = {}
+             part += 1
+     except Exception as e:
+         print(f"An error occurred during split_files: {e}")
+         shutil.rmtree(tmp_path)
+         raise
+
+
+ def apply_delta_low_cpu_mem(base_model_path, target_model_path, delta_path):
+     delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
+     delta_config = AutoConfig.from_pretrained(delta_path)
+
+     if os.path.exists(target_model_path):
+         shutil.rmtree(target_model_path)
+     os.makedirs(target_model_path)
+
+     split_size = 4 * GB
+
+     with tempfile.TemporaryDirectory() as tmp_base_path, tempfile.TemporaryDirectory() as tmp_delta_path:
+         print(f"Split files for the base model to {tmp_base_path}")
+         split_files(base_model_path, tmp_base_path, split_size)
+         print(f"Split files for the delta weights to {tmp_delta_path}")
+         split_files(delta_path, tmp_delta_path, split_size)
+
+         base_pattern = os.path.join(tmp_base_path, "pytorch_model-*.bin")
+         base_files = glob.glob(base_pattern)
+         base_state_dict = torch.load(base_files[0])
+         delta_pattern = os.path.join(tmp_delta_path, "pytorch_model-*.bin")
+         delta_files = glob.glob(delta_pattern)
+         # delta_state_dict = torch.load(delta_files[0])
+
+         print("Applying the delta")
+         weight_map = {}
+         total_size = 0
+
+         for i, delta_file in tqdm(enumerate(delta_files)):
+             state_dict = torch.load(delta_file)
+             file_name = f"pytorch_model-{i}.bin"
+             for name, param in state_dict.items():
+                 if name not in base_state_dict:
+                     for base_file in base_files:
+                         base_state_dict = torch.load(base_file)
+                         gc.collect()
+                         if name in base_state_dict:
+                             break
+                 if state_dict[name].shape == base_state_dict[name].shape:
+                     state_dict[name] += base_state_dict[name]
+                 else:
+                     print(name)
+                 weight_map[name] = file_name
+                 total_size += param.numel() * param.element_size()
+             gc.collect()
+             torch.save(state_dict, os.path.join(target_model_path, file_name))
+
+         with open(
+             os.path.join(target_model_path, "pytorch_model.bin.index.json"), "w"
+         ) as f:
+             json.dump(
+                 {"weight_map": weight_map, "metadata": {"total_size": total_size}}, f
+             )
+
+     print(f"Saving the target model to {target_model_path}")
+     delta_tokenizer.save_pretrained(target_model_path)
+     delta_config.save_pretrained(target_model_path)
+
+
+ def apply_delta(base_model_path, target_model_path, delta_path):
+     print(f"Loading the delta weights from {delta_path}")
+     delta_tokenizer = AutoTokenizer.from_pretrained(delta_path, use_fast=False)
+     delta = AutoModelForCausalLM.from_pretrained(
+         delta_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     )
+
+     print(f"Loading the base model from {base_model_path}")
+     base = AutoModelForCausalLM.from_pretrained(
+         base_model_path, torch_dtype=torch.float16, low_cpu_mem_usage=True
+     )
+
+     print("Applying the delta")
+     for name, param in tqdm(delta.state_dict().items(), desc="Applying delta"):
+         assert name in base.state_dict()
+         if param.shape == base.state_dict()[name].shape:
+             param.data += base.state_dict()[name]
+         else:
+             print(name)
+
+     print(f"Saving the target model to {target_model_path}")
+     delta.save_pretrained(target_model_path)
+     delta_tokenizer.save_pretrained(target_model_path)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--base-model-path", type=str, required=True)
+     parser.add_argument("--target-model-path", type=str, required=True)
+     parser.add_argument("--delta-path", type=str, required=True)
+     parser.add_argument(
+         "--low-cpu-mem",
+         action="store_true",
+         help="Lower the cpu memory usage. This will split large files and use "
+         "disk as swap to reduce the memory usage below 10GB.",
+     )
+     args = parser.parse_args()
+
+     if args.low_cpu_mem:
+         apply_delta_low_cpu_mem(
+             args.base_model_path, args.target_model_path, args.delta_path
+         )
+     else:
+         apply_delta(args.base_model_path, args.target_model_path, args.delta_path)
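The low-CPU-memory path in `merge_delta.py` relies on `split_files()`, which greedily packs tensors into shards until adding the next one would exceed `split_size`. A standalone sketch of that packing logic (slightly tidied so an empty trailing shard is never emitted; tensor byte sizes stand in for `param.numel() * param.element_size()`):

```python
def split_by_size(params, split_size):
    """Greedy sharding as in split_files(): start a new shard once adding
    the next tensor would push the current shard past split_size."""
    shards, current, current_size = [], {}, 0
    for name, nbytes in params:
        if current_size + nbytes > split_size and current:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = nbytes
        current_size += nbytes
    if current:
        shards.append(current)
    return shards


# hypothetical tensors with their byte sizes, 5-byte shard budget
shards = split_by_size([("a", 3), ("b", 3), ("c", 5)], split_size=5)
print(shards)  # three shards: {"a": 3}, {"b": 3}, {"c": 5}
```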
convert/zero_to_fp32.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # Copyright (c) Microsoft Corporation.
4
+ # SPDX-License-Identifier: Apache-2.0
5
+
6
+ # DeepSpeed Team
7
+
8
+ # This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
9
+ # copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
10
+ # the future. Once extracted, the weights don't require DeepSpeed and can be used in any
11
+ # application.
12
+ #
13
+ # example: python zero_to_fp32.py . pytorch_model.bin
14
+
15
+ import argparse
16
+ import torch
17
+ import glob
18
+ import math
19
+ import os
20
+ import re
21
+ from collections import OrderedDict
22
+ from dataclasses import dataclass
23
+ from tqdm import tqdm
24
+
25
+ # while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
26
+ # DeepSpeed data structures it has to be available in the current python environment.
27
+ from deepspeed.utils import logger
28
+ from deepspeed.checkpoint.constants import (DS_VERSION, OPTIMIZER_STATE_DICT, SINGLE_PARTITION_OF_FP32_GROUPS,
29
+ FP32_FLAT_GROUPS, ZERO_STAGE, PARTITION_COUNT, PARAM_SHAPES, BUFFER_NAMES,
30
+ FROZEN_PARAM_SHAPES, FROZEN_PARAM_FRAGMENTS)
31
+
32
+
33
+ @dataclass
34
+ class zero_model_state:
35
+ buffers: dict()
36
+ param_shapes: dict()
37
+ shared_params: list
38
+ ds_version: int
39
+ frozen_param_shapes: dict()
40
+ frozen_param_fragments: dict()
41
+
42
+
43
+ debug = 0
44
+
45
+ # load to cpu
46
+ device = torch.device('cpu')
47
+
48
+
49
+ def atoi(text):
50
+ return int(text) if text.isdigit() else text
51
+
52
+
53
+ def natural_keys(text):
54
+ '''
55
+ alist.sort(key=natural_keys) sorts in human order
56
+ http://nedbatchelder.com/blog/200712/human_sorting.html
57
+ (See Toothy's implementation in the comments)
58
+ '''
59
+ return [atoi(c) for c in re.split(r'(\d+)', text)]
60
+
61
+
62
+ def get_model_state_file(checkpoint_dir, zero_stage):
63
+ if not os.path.isdir(checkpoint_dir):
64
+ raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
65
+
66
+ # there should be only one file
67
+ if zero_stage == 2:
68
+ file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
69
+ elif zero_stage == 3:
70
+ file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
71
+
72
+ if not os.path.exists(file):
73
+ raise FileNotFoundError(f"can't find model states file at '{file}'")
74
+
75
+ return file
76
+
77
+
78
+ def get_checkpoint_files(checkpoint_dir, glob_pattern):
79
+ # XXX: need to test that this simple glob rule works for multi-node setup too
80
+ ckpt_files = sorted(glob.glob(os.path.join(checkpoint_dir, glob_pattern)), key=natural_keys)
81
+
82
+ if len(ckpt_files) == 0:
83
+ raise FileNotFoundError(f"can't find {glob_pattern} files in directory '{checkpoint_dir}'")
84
+
85
+ return ckpt_files
86
+
87
+
88
+ def get_optim_files(checkpoint_dir):
89
+ return get_checkpoint_files(checkpoint_dir, "*_optim_states.pt")
90
+
91
+
92
+ def get_model_state_files(checkpoint_dir):
93
+ return get_checkpoint_files(checkpoint_dir, "*_model_states.pt")
94
+
95
+
96
+ def parse_model_states(files):
97
+ zero_model_states = []
98
+ for file in files:
99
+ state_dict = torch.load(file, map_location=device)
100
+
101
+ if BUFFER_NAMES not in state_dict:
102
+ raise ValueError(f"{file} is not a model state checkpoint")
103
+ buffer_names = state_dict[BUFFER_NAMES]
104
+ if debug:
105
+ print("Found buffers:", buffer_names)
106
+
107
+ # recover just the buffers while restoring them to fp32 if they were saved in fp16
108
+ buffers = {k: v.float() for k, v in state_dict["module"].items() if k in buffer_names}
109
+ param_shapes = state_dict[PARAM_SHAPES]
110
+
111
+ # collect parameters that are included in param_shapes
112
+ param_names = []
113
+ for s in param_shapes:
114
+ for name in s.keys():
115
+ param_names.append(name)
116
+
117
+ # update with frozen parameters
118
+ frozen_param_shapes = state_dict.get(FROZEN_PARAM_SHAPES, None)
119
+ if frozen_param_shapes is not None:
120
+ if debug:
121
+ print(f"Found frozen_param_shapes: {frozen_param_shapes}")
122
+ param_names += list(frozen_param_shapes.keys())
123
+
124
+ # record shared parameters so that they can be recovered based on partners
125
+ # this is because such parameters holding reference only are not saved by optimizer
126
+ shared_params = []
127
+ for param in state_dict["module"]:
128
+ if param not in [*param_names, *buffer_names]:
129
+ for share_param in state_dict["module"]:
130
+ if (state_dict["module"][share_param].data_ptr() == state_dict["module"][param].data_ptr()
131
+ and share_param != param):
132
+ shared_params.append([param, share_param])
133
+ break
134
+
135
+ ds_version = state_dict.get(DS_VERSION, None)
136
+
137
+ frozen_param_fragments = state_dict.get(FROZEN_PARAM_FRAGMENTS, None)
138
+
139
+ z_model_state = zero_model_state(buffers=buffers,
140
+ param_shapes=param_shapes,
141
+ shared_params=shared_params,
142
+ ds_version=ds_version,
143
+ frozen_param_shapes=frozen_param_shapes,
144
+ frozen_param_fragments=frozen_param_fragments)
145
+ zero_model_states.append(z_model_state)
146
+
147
+ return zero_model_states
148
+
149
+
150
+ def parse_optim_states(files, ds_checkpoint_dir):
151
+
152
+ total_files = len(files)
153
+ state_dicts = []
154
+ for i, f in enumerate(tqdm(files)):
155
+ state_dicts.append(torch.load(f, map_location=device))
156
+ if i == 0:
157
+ if not ZERO_STAGE in state_dicts[0][OPTIMIZER_STATE_DICT]:
158
+ raise ValueError(f"{files[0]} is not a zero checkpoint")
159
+ zero_stage = state_dicts[0][OPTIMIZER_STATE_DICT][ZERO_STAGE]
160
+ world_size = state_dicts[0][OPTIMIZER_STATE_DICT][PARTITION_COUNT]
161
+
162
+ # For ZeRO-2 each param group can have different partition_count as data parallelism for expert
163
+ # parameters can be different from data parallelism for non-expert parameters. So we can just
164
+ # use the max of the partition_count to get the dp world_size.
165
+
166
+ if type(world_size) is list:
167
+ world_size = max(world_size)
168
+
169
+ if world_size != total_files:
170
+ raise ValueError(
171
+ f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
172
+ "Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
173
+ )
174
+
175
+ # the groups are named differently in each stage
176
+ if zero_stage == 2:
177
+ fp32_groups_key = SINGLE_PARTITION_OF_FP32_GROUPS
178
+ elif zero_stage == 3:
179
+ fp32_groups_key = FP32_FLAT_GROUPS
180
+ else:
181
+ raise ValueError(f"unknown zero stage {zero_stage}")
182
+
183
+ key_list = list(state_dicts[-1][OPTIMIZER_STATE_DICT].keys())
184
+ for key in key_list:
185
+ if zero_stage == 2:
186
+ if key != fp32_groups_key:
187
+ del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
188
+ elif zero_stage == 3:
189
+ if key == fp32_groups_key:
190
+ value = torch.cat(state_dicts[-1][OPTIMIZER_STATE_DICT][fp32_groups_key], 0)
191
+ del state_dicts[-1][OPTIMIZER_STATE_DICT][key]
192
+ if key == fp32_groups_key:
193
+ state_dicts[-1][OPTIMIZER_STATE_DICT][key] = value
194
+
195
+ print('zero_stage:', zero_stage)
196
+ fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
197
+ # if zero_stage == 2:
198
+ # # fp32_flat_groups = [state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key] for i in range(len(state_dicts))]
199
+ # elif zero_stage == 3:
200
+ # # if there is more than one param group, there will be multiple flattened tensors - one
201
+ # # flattened tensor per group - for simplicity merge them into a single tensor
202
+ # #
203
+ # # XXX: could make the script more memory efficient for when there are multiple groups - it
204
+ # # will require matching the sub-lists of param_shapes for each param group flattened tensor
205
+
206
+ # print('start!')
207
+ # # fp32_flat_groups = [
208
+ # # torch.cat(state_dicts[i][OPTIMIZER_STATE_DICT][fp32_groups_key], 0) for i in range(len(state_dicts))
209
+ # # ]
210
+
211
+ return zero_stage, world_size, fp32_flat_groups
212
+
213
+
214
+ def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
215
+ """
216
+ Returns fp32 state_dict reconstructed from ds checkpoint
217
+
218
+ Args:
219
+ - ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
220
+
221
+ """
222
+ print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
223
+
224
+ optim_files = get_optim_files(ds_checkpoint_dir)
225
+ zero_stage, world_size, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
226
+ print(f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
227
+
228
+ model_files = get_model_state_files(ds_checkpoint_dir)
229
+
230
+ zero_model_states = parse_model_states(model_files)
231
+ print(f'Parsing checkpoint created by deepspeed=={zero_model_states[0].ds_version}')
232
+
233
+ if zero_stage == 2:
234
+ return _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states)
235
+ elif zero_stage == 3:
236
+ return _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states)
237
+
238
+
239
+ def _zero2_merge_frozen_params(state_dict, zero_model_states):
240
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
241
+ return
242
+
243
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
244
+ frozen_param_fragments = zero_model_states[0].frozen_param_fragments
245
+
246
+ if debug:
247
+ num_elem = sum(s.numel() for s in frozen_param_shapes.values())
248
+ print(f'rank 0: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
249
+
250
+ wanted_params = len(frozen_param_shapes)
251
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
252
+ avail_numel = sum([p.numel() for p in frozen_param_fragments.values()])
253
+ print(f'Frozen params: Have {avail_numel} numels to process.')
254
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
255
+
256
+ total_params = 0
257
+ total_numel = 0
258
+ for name, shape in frozen_param_shapes.items():
259
+ total_params += 1
260
+ unpartitioned_numel = shape.numel()
261
+ total_numel += unpartitioned_numel
262
+
263
+ state_dict[name] = frozen_param_fragments[name]
264
+
265
+ if debug:
266
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
267
+
268
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
269
+
270
+
271
+ def _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
272
+ param_shapes = zero_model_states[0].param_shapes
273
+
274
+ # Reconstruction protocol:
275
+ #
276
+ # XXX: document this
277
+
278
+ if debug:
279
+ for i in range(world_size):
280
+ for j in range(len(fp32_flat_groups[0])):
281
+ print(f"{FP32_FLAT_GROUPS}[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
282
+
283
+ # XXX: memory usage doubles here (zero2)
284
+ num_param_groups = len(fp32_flat_groups[0])
285
+ merged_single_partition_of_fp32_groups = []
286
+ for i in range(num_param_groups):
287
+ merged_partitions = [sd[i] for sd in fp32_flat_groups]
288
+ full_single_fp32_vector = torch.cat(merged_partitions, 0)
289
+ merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
290
+ avail_numel = sum(
291
+ [full_single_fp32_vector.numel() for full_single_fp32_vector in merged_single_partition_of_fp32_groups])
292
+
293
+ if debug:
294
+ wanted_params = sum([len(shapes) for shapes in param_shapes])
295
+ wanted_numel = sum([sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
296
+ # not asserting if there is a mismatch due to possible padding
297
+ print(f"Have {avail_numel} numels to process.")
298
+ print(f"Need {wanted_numel} numels in {wanted_params} params.")
299
+
300
+ # params
301
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
302
+ # out-of-core computing solution
303
+ total_numel = 0
304
+ total_params = 0
305
+ for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
306
+ offset = 0
307
+ avail_numel = full_single_fp32_vector.numel()
308
+ for name, shape in shapes.items():
309
+
310
+ unpartitioned_numel = shape.numel()
311
+ total_numel += unpartitioned_numel
312
+ total_params += 1
313
+
314
+ if debug:
315
+ print(f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} ")
316
+ state_dict[name] = full_single_fp32_vector.narrow(0, offset, unpartitioned_numel).view(shape)
317
+ offset += unpartitioned_numel
318
+
319
+ # Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
320
+ # avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
321
+ # paddings performed in the code it's almost impossible to predict the exact numbers w/o the
322
+ # live optimizer object, so we are checking that the numbers are within the right range
323
+ align_to = 2 * world_size
324
+
325
+ def zero2_align(x):
326
+ return align_to * math.ceil(x / align_to)
327
+
328
+ if debug:
329
+ print(f"original offset={offset}, avail_numel={avail_numel}")
330
+
331
+ offset = zero2_align(offset)
332
+ avail_numel = zero2_align(avail_numel)
333
+
334
+ if debug:
335
+ print(f"aligned offset={offset}, avail_numel={avail_numel}")
336
+
337
+ # Sanity check
338
+ if offset != avail_numel:
339
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
340
+
341
+ print(f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements")
342
+
343
+
344
+ def _get_fp32_state_dict_from_zero2_checkpoint(world_size, fp32_flat_groups, zero_model_states):
345
+ state_dict = OrderedDict()
346
+
347
+ # buffers
348
+ buffers = zero_model_states[0].buffers
349
+ state_dict.update(buffers)
350
+ if debug:
351
+ print(f"added {len(buffers)} buffers")
352
+
353
+ _zero2_merge_frozen_params(state_dict, zero_model_states)
354
+
355
+ _zero2_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
356
+
357
+ # recover shared parameters
358
+ for pair in zero_model_states[0].shared_params:
359
+ state_dict[pair[0]] = state_dict[pair[1]]
360
+
361
+ return state_dict
362
+
363
+
364
+ def zero3_partitioned_param_info(unpartitioned_numel, world_size):
365
+ remainder = unpartitioned_numel % world_size
366
+ padding_numel = (world_size - remainder) if remainder else 0
367
+ partitioned_numel = math.ceil(unpartitioned_numel / world_size)
368
+ return partitioned_numel, padding_numel
369
+
370
+
371
+ def _zero3_merge_frozen_params(state_dict, world_size, zero_model_states):
372
+ if zero_model_states[0].frozen_param_shapes is None or len(zero_model_states[0].frozen_param_shapes) == 0:
373
+ return
374
+
375
+ if debug:
376
+ for i in range(world_size):
377
+ num_elem = sum(s.numel() for s in zero_model_states[i].frozen_param_fragments.values())
378
+ print(f'rank {i}: {FROZEN_PARAM_SHAPES}.numel = {num_elem}')
379
+
380
+ frozen_param_shapes = zero_model_states[0].frozen_param_shapes
381
+ wanted_params = len(frozen_param_shapes)
382
+ wanted_numel = sum(s.numel() for s in frozen_param_shapes.values())
383
+ avail_numel = sum([p.numel() for p in zero_model_states[0].frozen_param_fragments.values()]) * world_size
384
+ print(f'Frozen params: Have {avail_numel} numels to process.')
385
+ print(f'Frozen params: Need {wanted_numel} numels in {wanted_params} params')
386
+
387
+ total_params = 0
388
+ total_numel = 0
389
+ for name, shape in tqdm(zero_model_states[0].frozen_param_shapes.items()):
390
+ total_params += 1
391
+ unpartitioned_numel = shape.numel()
392
+ total_numel += unpartitioned_numel
393
+
394
+ param_frags = tuple(model_state.frozen_param_fragments[name] for model_state in zero_model_states)
395
+ state_dict[name] = torch.cat(param_frags, 0).narrow(0, 0, unpartitioned_numel).view(shape)
396
+
397
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
398
+
399
+ if debug:
400
+ print(
401
+ f"Frozen params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
402
+ )
403
+
404
+ print(f"Reconstructed Frozen fp32 state dict with {total_params} params {total_numel} elements")
405
+
406
+
407
+ def _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states):
408
+ param_shapes = zero_model_states[0].param_shapes
409
+ avail_numel = fp32_flat_groups[0].numel() * world_size
410
+ # Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
411
+ # param, re-consolidating each param, while dealing with padding if any
412
+
413
+ # merge list of dicts, preserving order
414
+ param_shapes = {k: v for d in param_shapes for k, v in d.items()}
415
+
416
+ if debug:
417
+ for i in range(world_size):
418
+ print(f"{FP32_FLAT_GROUPS}[{i}].shape={fp32_flat_groups[i].shape}")
419
+
420
+ wanted_params = len(param_shapes)
421
+ wanted_numel = sum(shape.numel() for shape in param_shapes.values())
422
+ # not asserting if there is a mismatch due to possible padding
423
+ avail_numel = fp32_flat_groups[0].numel() * world_size
424
+ print(f"Trainable params: Have {avail_numel} numels to process.")
425
+ print(f"Trainable params: Need {wanted_numel} numels in {wanted_params} params.")
426
+
427
+ # params
428
+ # XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
429
+ # out-of-core computing solution
430
+ offset = 0
431
+ total_numel = 0
432
+ total_params = 0
433
+ for name, shape in tqdm(param_shapes.items()):
434
+
435
+ unpartitioned_numel = shape.numel()
436
+ total_numel += unpartitioned_numel
437
+ total_params += 1
438
+
439
+ partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
440
+
441
+ if debug:
442
+ print(
443
+ f"Trainable params: {total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
444
+ )
445
+
446
+ # XXX: memory usage doubles here
447
+ state_dict[name] = torch.cat(
448
+ tuple(fp32_flat_groups[i].narrow(0, offset, partitioned_numel) for i in range(world_size)),
449
+ 0).narrow(0, 0, unpartitioned_numel).view(shape)
450
+ offset += partitioned_numel
451
+
452
+ offset *= world_size
453
+
454
+ # Sanity check
455
+ if offset != avail_numel:
456
+ raise ValueError(f"consumed {offset} numels out of {avail_numel} - something is wrong")
457
+
458
+ print(f"Reconstructed Trainable fp32 state dict with {total_params} params {total_numel} elements")
459
+
460
+
461
+ def _get_fp32_state_dict_from_zero3_checkpoint(world_size, fp32_flat_groups, zero_model_states):
462
+ state_dict = OrderedDict()
463
+
464
+ # buffers
465
+ buffers = zero_model_states[0].buffers
466
+ state_dict.update(buffers)
467
+ if debug:
468
+ print(f"added {len(buffers)} buffers")
469
+
470
+ _zero3_merge_frozen_params(state_dict, world_size, zero_model_states)
471
+
472
+ _zero3_merge_trainable_params(state_dict, world_size, fp32_flat_groups, zero_model_states)
473
+
474
+ # recover shared parameters
475
+ for pair in zero_model_states[0].shared_params:
476
+ state_dict[pair[0]] = state_dict[pair[1]]
477
+
478
+ return state_dict
479
+
480
+
481
+ def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
482
+ """
483
+ Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
484
+ ``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
485
+ via a model hub.
486
+
487
+ Args:
488
+ - ``checkpoint_dir``: path to the desired checkpoint folder
489
+ - ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
490
+
491
+ Returns:
492
+ - pytorch ``state_dict``
493
+
494
+ Note: this approach may not work if your application doesn't have sufficient free CPU memory and
495
+ you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
496
+ the checkpoint.
497
+
498
+ A typical usage might be ::
499
+
500
+ from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
501
+ # do the training and checkpoint saving
502
+ state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
503
+ model = model.cpu() # move to cpu
504
+ model.load_state_dict(state_dict)
505
+ # submit to model hub or save the model to share with others
506
+
507
+ In this example the ``model`` will no longer be usable in the deepspeed context of the same
508
+ application. i.e. you will need to re-initialize the deepspeed engine, since
509
+ ``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
510
+
511
+ If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
512
+
513
+ """
514
+ if tag is None:
515
+ latest_path = os.path.join(checkpoint_dir, 'latest')
516
+ if os.path.isfile(latest_path):
517
+ with open(latest_path, 'r') as fd:
518
+             tag = fd.read().strip()
+     else:
+         raise ValueError(f"Unable to find 'latest' file at {latest_path}")
+
+     ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
+
+     if not os.path.isdir(ds_checkpoint_dir):
+         raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
+
+     return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
+
+
+ def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
+     """
+     Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
+     loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
+
+     Args:
+         - ``checkpoint_dir``: path to the desired checkpoint folder (one that contains the tag-folder, like ``global_step14``)
+         - ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
+         - ``tag``: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+     """
+
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+     print(f"Saving fp32 state dict to {output_file}")
+     torch.save(state_dict, output_file)
+
+
+ def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
+     """
+     1. Put the provided model to cpu
+     2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
+     3. Load it into the provided model
+
+     Args:
+         - ``model``: the model object to update
+         - ``checkpoint_dir``: path to the desired checkpoint folder (one that contains the tag-folder, like ``global_step14``)
+         - ``tag``: checkpoint tag used as a unique identifier for the checkpoint. If not provided, will attempt to load the tag from the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
+
+     Returns:
+         - ``model``: modified model
+
+     Make sure you have plenty of CPU memory available before you call this function. If you don't
+     have enough, use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
+     conveniently placed for you in the checkpoint folder.
+
+     A typical usage might be ::
+
+         from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
+         model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
+         # submit to model hub or save the model to share with others
+
+     Note that once this has been run, the ``model`` will no longer be usable in the DeepSpeed
+     context of the same application, i.e. you will need to re-initialize the DeepSpeed engine,
+     since ``model.load_state_dict(state_dict)`` will remove all the DeepSpeed magic from it.
+
+     """
+     logger.info("Extracting fp32 weights")
+     state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
+
+     logger.info("Overwriting model with fp32 weights")
+     model = model.cpu()
+     model.load_state_dict(state_dict, strict=False)
+
+     return model
+
+
+ if __name__ == "__main__":
+
+     parser = argparse.ArgumentParser()
+     parser.add_argument("checkpoint_dir",
+                         type=str,
+                         help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
+     parser.add_argument(
+         "output_file",
+         type=str,
+         help="path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)")
+     parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
+     args = parser.parse_args()
+
+     debug = args.debug
+
+     convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
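The script above resolves the checkpoint tag from DeepSpeed's `latest` file before consolidating weights. A minimal standalone sketch of that resolution step (the helper name `resolve_checkpoint_dir` is illustrative, not part of the repository or DeepSpeed):

```python
import os
import tempfile

def resolve_checkpoint_dir(checkpoint_dir, tag=None):
    # Mirror the diff above: when no tag is given, read it from the 'latest' file.
    if tag is None:
        latest_path = os.path.join(checkpoint_dir, "latest")
        if os.path.isfile(latest_path):
            with open(latest_path) as fd:
                tag = fd.read().strip()
        else:
            raise ValueError(f"Unable to find 'latest' file at {latest_path}")
    ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
    if not os.path.isdir(ds_checkpoint_dir):
        raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
    return ds_checkpoint_dir

# Example: a fake checkpoint folder whose 'latest' file points at global_step14
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "global_step14"))
with open(os.path.join(root, "latest"), "w") as fd:
    fd.write("global_step14\n")
resolved = resolve_checkpoint_dir(root)
print(os.path.basename(resolved))  # global_step14
```

This is why the docstrings say `checkpoint_dir` must be the folder that *contains* the tag-folder, not the tag-folder itself.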
data.py ADDED
@@ -0,0 +1,1035 @@
+ import copy
+ import random
+ import argparse
+ import os
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from tqdm import tqdm
+ from collections import defaultdict
+ import torch.distributed as dist
+ import logging
+ import re
+ import pdb
+ import json
+ from prompt import sft_prompt, all_prompt
+ import numpy as np
+
+ class BaseDataset(Dataset):
+     def __init__(self, args):
+         super().__init__()
+
+         self.args = args
+         self.dataset = args.dataset
+         self.data_path = os.path.join(args.data_path, self.dataset)
+
+         self.max_his_len = args.max_his_len
+         self.his_sep = args.his_sep
+         self.index_file = args.index_file
+         self.user_index_file = args.user_index_file
+         self.add_prefix = args.add_prefix
+
+         self.new_tokens = None
+         self.allowed_tokens = None
+         self.all_items = None
+
+     def _load_data(self):
+         with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+             self.indices = json.load(f)
+
+     def get_new_tokens(self):
+         if self.new_tokens is not None:
+             return self.new_tokens
+
+         self.new_tokens = set()
+         for index in self.indices.values():
+             for token in index:
+                 self.new_tokens.add(token)
+         self.new_tokens = sorted(list(self.new_tokens))
+
+         return self.new_tokens
+
+     def get_all_items(self):
+         if self.all_items is not None:
+             return self.all_items
+
+         self.all_items = set()
+         for index in self.indices.values():
+             self.all_items.add("".join(index))
+
+         return self.all_items
+
+     def get_prefix_allowed_tokens_fn(self, tokenizer):
+         if self.allowed_tokens is None:
+             self.allowed_tokens = {}
+             for index in self.indices.values():
+                 for i, token in enumerate(index):
+                     token_id = tokenizer(token)["input_ids"][1]
+                     if i not in self.allowed_tokens.keys():
+                         self.allowed_tokens[i] = set()
+                     self.allowed_tokens[i].add(token_id)
+             self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
+         sep = tokenizer("Response:")["input_ids"][1:]
+
+         def prefix_allowed_tokens_fn(batch_id, sentence):
+             sentence = sentence.tolist()
+             reversed_sent = sentence[::-1]
+             for i in range(len(reversed_sent)):
+                 if reversed_sent[i:i + len(sep)] == sep[::-1]:
+                     return list(self.allowed_tokens[i])
+
+         return prefix_allowed_tokens_fn
+
+     def _process_data(self):
+         raise NotImplementedError
+
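`get_prefix_allowed_tokens_fn` above constrains beam search so that, at each position after the `Response:` separator, the model can only emit tokens valid for that level of the learned item index (plus EOS once the index is complete). A minimal self-contained sketch of the same position-lookup idea, with plain lists in place of a tokenizer (all token ids here are made up for illustration):

```python
# Map each generation position (distance past the separator) to the
# set of allowed token ids for that level of the item index.
allowed_tokens = {
    0: {11, 12},   # level-0 codewords of the item index
    1: {21, 22},   # level-1 codewords
    2: {99},       # after the full index: end-of-sequence only
}
SEP = [7, 8]  # token ids of the "Response:" separator

def prefix_allowed_tokens_fn(sentence):
    # Walk backwards until the separator is found; the distance walked
    # equals the number of index tokens generated so far.
    reversed_sent = sentence[::-1]
    for i in range(len(reversed_sent)):
        if reversed_sent[i:i + len(SEP)] == SEP[::-1]:
            return sorted(allowed_tokens[i])
    return []

print(prefix_allowed_tokens_fn([5, 7, 8]))      # [11, 12] -- level-0 set
print(prefix_allowed_tokens_fn([5, 7, 8, 11]))  # [21, 22] -- level-1 set
```

The real method builds `allowed_tokens` from the tokenizer ids of the item-index tokens, so generated sequences are guaranteed to decode to well-formed item indices.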
+ class UserFeatDataset(BaseDataset):
+     def __init__(self, args, task = "pref2user", prompt_sample_num = 1, sample_num = -1):
+         super().__init__(args)
+
+         self.task = task.lower()
+         self.prompt_sample_num = prompt_sample_num
+         self.sample_num = sample_num
+
+         self.prompts = all_prompt[self.task]
+
+         self._load_data()
+         self.feat_data = self._process_data()
+
+     def _load_data(self):
+         with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
+             user_feat = json.load(f)
+         # >>> user_feat.keys()
+         # dict_keys(['user_explicit_preference', 'user_vague_intention'])
+         self.user_feat = user_feat['user_explicit_preference']
+         # >>> user_feat['0']
+         # ['The user is a passionate musician who enjoys exploring different types of musical instruments.']
+         # >>> len(user_feat)
+         # 24772
+
+     def _process_data(self):
+         feat_data = []
+         for uid in self.user_feat:
+             one_data = {}
+             one_data['user'] = uid
+
+             preference = " ".join(self.user_feat[uid])
+             preference = preference.strip().strip(".!?,;:`")
+             preference = preference.replace('{','').replace('}','')
+             one_data['preference'] = preference
+
+             feat_data.append(one_data)
+
+         if self.sample_num > 0:
+             all_idx = range(len(feat_data))
+             sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
+             feat_data = np.array(feat_data)[sample_idx].tolist()
+
+         return feat_data
+
+     def __len__(self):
+         return len(self.feat_data) * self.prompt_sample_num
+
+     def __getitem__(self, index):
+         idx = index // self.prompt_sample_num
+         d = self.feat_data[idx]
+         prompt_id = random.randint(0, len(self.prompts) - 1)
+         prompt = self.prompts[prompt_id]
+
+         if self.task == 'pref2user':
+             instruction = prompt['instruction'].format(preference = d['preference'])
+             input = sft_prompt.format(instruction = instruction, response = "")
+             output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+             return dict(
+                 input_ids = input,
+                 labels = output,
+                 inters = 'placeholder',
+                 item = 'placeholder',
+                 users = 'placeholder',
+                 user = d['user'],
+                 task = self.task
+             )
+         elif self.task == 'user2pref':
+             input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+             response = prompt["response"].format(preference = d['preference'])
+             output = sft_prompt.format(instruction = prompt["instruction"], response = response)
+             return dict(
+                 input_ids = input,
+                 labels = output,
+                 inters = 'placeholder',
+                 item = 'placeholder',
+                 users = 'placeholder',
+                 user = d['user'],
+                 task = self.task
+             )
+         else:
+             raise NotImplementedError
+
+ class UserSearchDataset(BaseDataset):
+     def __init__(self, args, prompt_sample_num = 1, prompt_id = 0, sample_num = -1):
+         super().__init__(args)
+
+         self.prompt_sample_num = prompt_sample_num
+         self.prompt_id = prompt_id
+         self.sample_num = sample_num
+
+         self.prompts = all_prompt["usersearch"]
+
+         self._load_data()
+         self.search_data = self._process_data()
+
+     def _load_data(self):
+         with open(os.path.join(self.data_path, self.dataset + ".inter.user.json"), 'r') as f:
+             self.user_inters = json.load(f)
+
+     def _process_data(self):
+         search_data = []
+         for iid in self.user_inters.keys():
+             users = self.user_inters[iid]
+             for i in range(1, len(users)):
+                 one_data = {}
+                 one_data['item'] = iid
+                 one_data['user'] = str(users[i])
+                 history = users[:i]
+
+                 if len(history) > self.max_his_len:
+                     history = history[-self.max_his_len:]
+
+                 one_data['users'] = ''
+                 for user in history:
+                     one_data['users'] = one_data['users'] + str(user) + ','
+                 one_data['users'] = one_data['users'][:-1]
+
+                 search_data.append(one_data)
+
+         if self.sample_num > 0:
+             all_idx = range(len(search_data))
+             sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
+             search_data = np.array(search_data)[sample_idx].tolist()
+
+         return search_data
+
+     def __len__(self):
+         return len(self.search_data) * self.prompt_sample_num
+
+     def __getitem__(self, index):
+         idx = index // self.prompt_sample_num
+         d = self.search_data[idx]
+
+         prompt_id = random.randint(0, len(self.prompts) - 1)
+         prompt = self.prompts[prompt_id]
+
+         input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+         output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
+         return dict(
+             input_ids = input,
+             labels = output,
+             inters = 'placeholder',
+             item = d['item'],
+             users = d['users'],
+             user = d['user'],
+             task = 'usersearch'
+         )
+
+ # =====================================================================================================================
+ # seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item
+ # =====================================================================================================================
+
+ # seqrec
+ class SeqRecDataset(BaseDataset):
+     def __init__(self, args, mode="train",
+                  prompt_sample_num=1, prompt_id=0, sample_num=-1):
+         super().__init__(args)
+
+         self.mode = mode
+         self.prompt_sample_num = prompt_sample_num
+         self.prompt_id = prompt_id
+         self.sample_num = sample_num
+
+         self.prompts = all_prompt["seqrec"]
+
+         self._load_data()
+
+         if self.mode == 'train':
+             self.inter_data = self._process_train_data()
+         elif self.mode == 'valid':
+             self.sample_valid = args.sample_valid
+             self.valid_prompt_id = args.valid_prompt_id
+             self.inter_data = self._process_valid_data()
+             self._construct_valid_text()
+         elif self.mode == 'test':
+             self.inter_data = self._process_test_data()
+         else:
+             raise NotImplementedError
+
+     def _load_data(self):
+         with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
+             self.inters = json.load(f)
+
+     def _process_train_data(self):
+         inter_data = []
+         for uid in self.inters:
+             items = self.inters[uid][:-2]
+             for i in range(1, len(items)):
+                 one_data = dict()
+                 one_data['user'] = uid
+                 one_data['item'] = str(items[i])
+                 history = items[:i]
+                 if self.max_his_len > 0:
+                     history = history[-self.max_his_len:]
+                 one_data['inters'] = ''
+                 for item in history:
+                     one_data['inters'] = one_data['inters'] + str(item) + ','
+                 one_data['inters'] = one_data['inters'][:-1]
+
+                 inter_data.append(one_data)
+
+         return inter_data
+
+     def _process_valid_data(self):
+         inter_data = []
+         for uid in self.inters:
+             one_data = dict()
+             items = self.inters[uid]
+             one_data['user'] = uid
+             one_data['item'] = str(items[-2])
+             history = items[:-2]
+             if self.max_his_len > 0:
+                 history = history[-self.max_his_len:]
+             one_data['inters'] = ''
+             for item in history:
+                 one_data['inters'] = one_data['inters'] + str(item) + ','
+             one_data['inters'] = one_data['inters'][:-1]
+             inter_data.append(one_data)
+
+         return inter_data
+
+     def _process_test_data(self):
+         with open(self.index_file, 'r') as f:
+             self.indices = json.load(f)
+         self.remapped_inters = dict()
+         for uid, items in self.inters.items():
+             new_items = ["".join(self.indices[str(i)]) for i in items]
+             self.remapped_inters[uid] = new_items
+
+         with open(self.user_index_file, 'r') as f:
+             self.user_indices = json.load(f)
+         self.remapped_users = dict()
+         for uid in self.inters:
+             new_user = ''.join(self.user_indices[uid])
+             self.remapped_users[uid] = new_user
+
+         inter_data = []
+         for uid in self.remapped_inters:
+             one_data = dict()
+             one_data['user'] = self.remapped_users[uid]
+             items = self.remapped_inters[uid]
+             one_data['item'] = items[-1]
+             history = items[:-1]
+             if self.max_his_len > 0:
+                 history = history[-self.max_his_len:]
+             one_data["inters"] = self.his_sep.join(history)
+             inter_data.append(one_data)
+
+         if self.sample_num > 0:
+             all_inter_idx = range(len(inter_data))
+             sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace = False)
+             inter_data = np.array(inter_data)[sample_idx].tolist()
+
+         return inter_data
+
+     def set_prompt(self, prompt_id):
+         self.prompt_id = prompt_id
+
+     def __len__(self):
+         if self.mode == 'train':
+             return len(self.inter_data) * self.prompt_sample_num
+         elif self.mode == 'valid':
+             return len(self.valid_text_data)
+         elif self.mode == 'test':
+             return len(self.inter_data)
+         else:
+             raise NotImplementedError
+
+     def _construct_valid_text(self):
+         self.valid_text_data = []
+         if self.sample_valid:
+             all_prompt_ids = range(len(self.prompts))
+             for i in range(len(self.inter_data)):
+                 d = self.inter_data[i]
+                 prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
+                 for prompt_id in prompt_ids:
+                     prompt = self.prompts[prompt_id]
+                     input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+                     output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
+                     self.valid_text_data.append({
+                         "input_ids": input,
+                         "labels": output,
+                         "inters": d['inters'],
+                         "item": d['item'],
+                         "users": 'placeholder',
+                         "user": d['user'],
+                         "task": 'seqrec'})
+         else:
+             self.prompt_sample_num = 1
+             prompt = self.prompts[self.valid_prompt_id]
+             for i in range(len(self.inter_data)):
+                 d = self.inter_data[i]
+                 input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+                 output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
+                 self.valid_text_data.append({
+                     "input_ids": input, "labels": output,
+                     "inters": d['inters'], "item": d['item'], "users": 'placeholder', "user": d['user'],
+                     "task": 'seqrec'})
+
+     def _get_text_data(self, data, prompt):
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+
+         input = sft_prompt.format(instruction = instruction, response = "")
+         output = sft_prompt.format(instruction = instruction, response = response)
+
+         if self.mode == 'test':
+             return input, response
+
+         return input, output
+
+     def __getitem__(self, index):
+         if self.mode == 'valid':
+             return self.valid_text_data[index]
+
+         idx = index // self.prompt_sample_num
+         d = self.inter_data[idx]
+
+         if self.mode == 'train':
+             prompt_id = random.randint(0, len(self.prompts) - 1)
+         elif self.mode == 'test':
+             prompt_id = self.prompt_id
+             prompt = self.prompts[prompt_id]
+             instruction = prompt["instruction"].format(**d)
+             response = prompt["response"].format(**d)
+             input = sft_prompt.format(instruction = instruction, response = "")
+             return dict(input_ids = input, labels = response)
+
+         prompt = self.prompts[prompt_id]
+
+         input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+         output = sft_prompt.format(instruction = prompt["instruction"], response = prompt["response"])
+
+         return dict(input_ids = input, labels = output, inters = d['inters'], item = d['item'], user = d['user'], task = 'seqrec', users = 'placeholder')
+
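`SeqRecDataset` follows the standard leave-one-out protocol: training pairs are drawn from `items[:-2]`, validation targets the second-to-last item, and testing targets the last item, with histories truncated to `max_his_len`. A small illustrative sketch of that split (standalone; `leave_one_out_split` is a hypothetical helper, not the repository's API):

```python
def leave_one_out_split(items, max_his_len=3):
    """Build (history, target) pairs the way SeqRecDataset's three modes do."""
    # Train: every prefix of items[:-2] predicts the next item.
    train = []
    for i in range(1, len(items) - 2):
        history = items[:i][-max_his_len:]
        train.append((history, items[i]))
    # Valid: full history up to items[-2]; Test: up to items[-1].
    valid = (items[:-2][-max_his_len:], items[-2])
    test = (items[:-1][-max_his_len:], items[-1])
    return train, valid, test

train, valid, test = leave_one_out_split([1, 2, 3, 4, 5])
print(train)  # [([1], 2), ([1, 2], 3)]
print(valid)  # ([1, 2, 3], 4)
print(test)   # ([2, 3, 4], 5)
```

In the real test path the item ids are additionally remapped through the learned indices (`"".join(self.indices[str(i)])`) so the model predicts index token sequences rather than raw ids.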
439
+ # itemsearch & query2item
440
+ class ItemSearchDataset(BaseDataset):
441
+ def __init__(self, args, mode="train", task = 'itemsearch',
442
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
443
+ super().__init__(args)
444
+
445
+ self.mode = mode
446
+ self.prompt_sample_num = prompt_sample_num
447
+ self.prompt_id = prompt_id
448
+ self.sample_num = sample_num
449
+
450
+ self.task = task.lower()
451
+ self.prompts = all_prompt[self.task]
452
+
453
+ self._load_data()
454
+ self.search_data = self._process_data()
455
+
456
+ def _load_data(self):
457
+ with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
458
+ self.user_info = json.load(f)
459
+
460
+ def _process_data(self):
461
+ search_data = []
462
+ user_explicit_preference = self.user_info["user_explicit_preference"]
463
+ user_vague_intention = self.user_info["user_vague_intention"]
464
+ if self.mode == 'train':
465
+ user_vague_intention = user_vague_intention["train"]
466
+ elif self.mode == 'test':
467
+ user_vague_intention = user_vague_intention["test"]
468
+ else:
469
+ raise NotImplementedError
470
+
471
+ for uid in user_explicit_preference.keys():
472
+ one_data = {}
473
+ one_data['user'] = uid
474
+ user_ep = user_explicit_preference[uid]
475
+ user_vi = user_vague_intention[uid]["querys"]
476
+ one_data["explicit_preferences"] = user_ep
477
+ one_data["user_related_intention"] = user_vi[0]
478
+ one_data["item_related_intention"] = user_vi[1]
479
+
480
+ iid = user_vague_intention[uid]["item"]
481
+ inters = user_vague_intention[uid]["inters"]
482
+
483
+ if len(inters) == 0:
484
+ continue
485
+
486
+ one_data["item"] = str(iid)
487
+
488
+ if self.max_his_len > 0:
489
+ inters = inters[-self.max_his_len:]
490
+ one_data["inters"] = ''
491
+ for item in inters:
492
+ one_data["inters"] = one_data["inters"] + str(item) + ','
493
+ one_data["inters"] = one_data["inters"][:-1]
494
+
495
+ search_data.append(one_data)
496
+
497
+ if self.sample_num > 0:
498
+ all_idx = range(len(search_data))
499
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
500
+ search_data = np.array(search_data)[sample_idx].tolist()
501
+
502
+ return search_data
503
+
504
+ def set_prompt(self, prompt_id):
505
+ self.prompt_id = prompt_id
506
+
507
+ def __len__(self):
508
+ if self.mode == 'train':
509
+ return len(self.search_data) * self.prompt_sample_num
510
+ elif self.mode == 'test':
511
+ return len(self.search_data)
512
+ else:
513
+ return len(self.search_data)
514
+
515
+ def _get_text_data(self, data, prompt):
516
+ instruction = prompt["instruction"].format(**data)
517
+ response = prompt["response"].format(**data)
518
+
519
+ input = sft_prompt.format(instruction = instruction, response = "")
520
+ output = sft_prompt.format(instruction = instruction, response = response)
521
+
522
+ if self.mode == 'test':
523
+ return input, response
524
+
525
+ return input, output
526
+
527
+ def __getitem__(self, index):
528
+ idx = index // self.prompt_sample_num
529
+
530
+ d = self.search_data[idx]
531
+ if self.mode == 'train':
532
+ prompt_id = random.randint(0, len(self.prompts) - 1)
533
+ elif self.mode == 'test':
534
+ prompt_id = self.prompt_id
535
+
536
+ prompt = self.prompts[prompt_id]
537
+
538
+ d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
539
+ d["explicit_preference"] = d["explicit_preference"].replace('{','').replace('}','')
540
+ d["user_related_intention"] = d["user_related_intention"].replace('{','').replace('}','')
541
+ d["item_related_intention"] = d["item_related_intention"].replace('{','').replace('}','')
542
+ all_querys = [d["user_related_intention"], d["item_related_intention"]]
543
+ d["query"] = random.choice(all_querys)
544
+
545
+ # d["query"] = d["query"].replace('{','').replace('}','')
546
+
547
+ if self.task == 'itemsearch':
548
+ sub_d = d.copy()
549
+ sub_d.pop('inters')
550
+ sub_d.pop('user')
551
+ instruction = prompt["instruction"].format(inters='{inters}', user='{user}', **sub_d)
552
+ input = sft_prompt.format(instruction = instruction, response = "")
553
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
554
+ return dict(input_ids = input, labels = output, inters = d['inters'], item = d['item'], user = d['user'], task = self.task, users = 'placeholder')
555
+ elif self.task == 'query2item':
556
+ sub_d = d.copy()
557
+ sub_d.pop('user')
558
+ instruction = prompt["instruction"].format(user='{user}', **sub_d)
559
+ input = sft_prompt.format(instruction = instruction, response = "")
560
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
561
+ return dict(input_ids = input, labels = output, inters = 'placeholder', item = d['item'], user = d['user'], task = self.task, users = 'placeholder')
562
+
563
+ # inters2title & inters2description & intertitles2item
564
+ class FusionSeqRecDataset(BaseDataset):
565
+ def __init__(self, args, mode="train", task = 'inters2title',
566
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
567
+ super().__init__(args)
568
+
569
+ self.mode = mode
570
+ self.prompt_sample_num = prompt_sample_num
571
+ self.prompt_id = prompt_id
572
+ self.sample_num = sample_num
573
+
574
+ self.task = task.lower()
575
+ self.prompts = all_prompt[self.task]
576
+
577
+ # load data
578
+ self._load_data()
579
+
580
+ # load data
581
+ if self.mode == 'train':
582
+ self.inter_data = self._process_train_data()
583
+ elif self.mode == 'valid':
584
+ self.sample_valid = args.sample_valid
585
+ self.valid_prompt_id = args.valid_prompt_id
586
+ self.inter_data = self._process_valid_data()
587
+ self._construct_valid_text()
588
+ elif self.mode == 'test':
589
+ self.inter_data = self._process_test_data()
590
+ else:
591
+ raise NotImplementedError
592
+
593
+ def _load_data(self):
594
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
595
+ self.inters = json.load(f)
596
+ with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
597
+ self.item_feat = json.load(f)
598
+
599
+ def _process_train_data(self):
600
+
601
+ inter_data = []
602
+ for uid in self.inters:
603
+ items = self.inters[uid][:-2]
604
+ for i in range(1, len(items)):
605
+ one_data = dict()
606
+ one_data["item"] = str(items[i])
607
+ one_data['user'] = uid
608
+ one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
609
+ one_data["title"] = one_data["title"].replace('{','').replace('}','')
610
+ one_data["description"] = self.item_feat[str(items[i])]["description"]
611
+ one_data["description"] = one_data["description"].replace('{','').replace('}','')
612
+
613
+ history = items[:i]
614
+ if self.max_his_len > 0:
615
+ history = history[-self.max_his_len:]
616
+
617
+ one_data['inters'] = ''
618
+ for item in history:
619
+ one_data['inters'] = one_data['inters'] + str(item) +','
620
+ one_data['inters'] = one_data['inters'][:-1]
621
+
622
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`").replace('{','').replace('}','') + "\"" for j in history]
623
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
624
+
625
+ inter_data.append(one_data)
626
+
627
+ if self.sample_num > 0:
628
+ all_inter_idx = range(len(inter_data))
629
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
630
+ inter_data = np.array(inter_data)[sample_idx].tolist()
631
+
632
+ return inter_data
633
+
634
+ def _process_valid_data(self):
635
+ inter_data = []
636
+ for uid in self.inters:
637
+ items = self.inters[uid]
638
+ one_data = dict()
639
+ one_data["item"] = str(items[-2])
640
+ one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
641
+ one_data["description"] = self.item_feat[str(items[-2])]["description"]
642
+ one_data["description"] = one_data["description"].replace('{','').replace('}','')
643
+
644
+ history = items[:-2]
645
+ if self.max_his_len > 0:
646
+ history = history[-self.max_his_len:]
647
+ one_data['inters'] = ''
648
+ for item in history:
649
+ one_data['inters'] = one_data['inters'] + str(item) +','
650
+ one_data['inters'] = one_data['inters'][:-1]
651
+
652
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
653
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
654
+
655
+ inter_data.append(one_data)
656
+
657
+ if self.sample_num > 0:
658
+ all_inter_idx = range(len(inter_data))
659
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
660
+ inter_data = np.array(inter_data)[sample_idx].tolist()
661
+
662
+ return inter_data
663
+
664
+ def _process_test_data(self):
665
+ inter_data = []
666
+ for uid in self.inters:
667
+ items = self.inters[uid]
668
+ one_data = dict()
669
+ one_data["item"] = str(items[-1])
670
+ one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
671
+ one_data["description"] = self.item_feat[str(items[-1])]["description"]
672
+
673
+ history = items[:-1]
674
+ if self.max_his_len > 0:
675
+ history = history[-self.max_his_len:]
676
+
677
+ one_data['inters'] = ''
678
+ for item in history:
679
+ one_data['inters'] = one_data['inters'] + str(item) +','
680
+ one_data['inters'] = one_data['inters'][:-1]
681
+
682
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
683
+ one_data["inter_titles"] = self.his_sep.join(inter_titles)
684
+
685
+ inter_data.append(one_data)
686
+
687
+ if self.sample_num > 0:
688
+ all_inter_idx = range(len(inter_data))
689
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
690
+ inter_data = np.array(inter_data)[sample_idx].tolist()
691
+
692
+ return inter_data
693
+
694
+ def set_prompt(self, prompt_id):
695
+ self.prompt_id = prompt_id
696
+
697
+ def __len__(self):
698
+ if self.mode == 'train':
699
+ return len(self.inter_data) * self.prompt_sample_num
700
+ elif self.mode == 'valid':
701
+ return len(self.valid_text_data)
702
+ elif self.mode == 'test':
703
+ return len(self.inter_data)
704
+ else:
705
+ raise NotImplementedError
706
+
707
+ def _construct_valid_text(self):
708
+ self.valid_text_data = []
709
+ if self.sample_valid:
710
+ all_prompt_ids = range(len(self.prompts))
711
+ for i in range(len(self.inter_data)):
712
+ d = self.inter_data[i]
713
+ prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
714
+ if self.task == 'inters2title':
715
+ for prompt_id in prompt_ids:
716
+ prompt = self.prompts[prompt_id]
717
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
718
+ response = prompt['response'].format(title = d['title'])
719
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
720
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
+ elif self.task == 'inters2description':
+ for prompt_id in prompt_ids:
+ prompt = self.prompts[prompt_id]
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
+ response = prompt['response'].format(description = d['description'])
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
+ elif self.task == 'intertitles2item':
+ for prompt_id in prompt_ids:
+ prompt = self.prompts[prompt_id]
+ instruction = prompt['instruction'].format(inter_titles = d['inter_titles'])
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': 'placeholder', 'item': d['item'], 'task': self.task})
+ else:
+ raise NotImplementedError
+ else:
+ self.prompt_sample_num = 1
+ prompt = self.prompts[self.valid_prompt_id]
+ for i in range(len(self.inter_data)):
+ d = self.inter_data[i]
+ if self.task == 'inters2title':
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
+ response = prompt['response'].format(title = d['title'])
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
+ elif self.task == 'inters2description':
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
+ response = prompt['response'].format(description = d['description'])
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': d['inters'], 'item': 'placeholder', 'task': self.task})
+ elif self.task == 'intertitles2item':
+ instruction = prompt['instruction'].format(inter_titles = d['inter_titles'])
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+ self.valid_text_data.append({"input_ids": input, "labels": output, 'inters': 'placeholder', 'item': d['item'], 'task': self.task})
+ else:
+ raise NotImplementedError
+
+ def _get_text_data(self, data, prompt):
+ instruction = prompt["instruction"].format(**data)
+ response = prompt["response"].format(**data)
+
+ input = sft_prompt.format(instruction=instruction, response="")
+ output = sft_prompt.format(instruction=instruction, response=response)
+
+ if self.mode == 'test':
+ return input, response
+
+ return input, output
+
+ def __getitem__(self, index):
+ if self.mode == 'valid':
+ return self.valid_text_data[index]
+
+ idx = index // self.prompt_sample_num
+ d = self.inter_data[idx]
+
+ if self.mode == 'train':
+ prompt_id = random.randint(0, len(self.prompts) - 1)
+ elif self.mode == 'test':
+ prompt_id = self.prompt_id
+
+ prompt = self.prompts[prompt_id]
+
+ if self.task == 'inters2title':
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
+ response = prompt['response'].format(title = d['title'])
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
+ return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = self.task, users = 'placeholder')
+ elif self.task == 'inters2description':
+ input = sft_prompt.format(instruction = prompt['instruction'], response = "")
+ response = prompt['response'].format(description = d['description'])
+ output = sft_prompt.format(instruction = prompt['instruction'], response = response)
+ return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = self.task, users = 'placeholder')
+ elif self.task == 'intertitles2item':
+ instruction = prompt['instruction'].format(user = '{user}', inter_titles = d['inter_titles'])
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+ return dict(input_ids = input, labels = output, inters = 'placeholder', user = d['user'], item = d['item'], task = self.task, users = 'placeholder')
+ else:
+ raise NotImplementedError
+
+ # preferenceobtain
+ class PreferenceObtainDataset(BaseDataset):
+ def __init__(self, args, prompt_sample_num=1, sample_num=-1):
+ super().__init__(args)
+
+ self.prompt_sample_num = prompt_sample_num
+ self.sample_num = sample_num
+
+ self.prompts = all_prompt["preferenceobtain"]
+
+ # load data
+ self._load_data()
+
+ self.preference_data = self._process_data()
+
+ def _load_data(self):
+ with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
+ self.user_info = json.load(f)
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
+ self.inters = json.load(f)
+
+ def _process_data(self):
+ preference_data = []
+ user_explicit_preference = self.user_info["user_explicit_preference"]
+
+ for uid in user_explicit_preference.keys():
+ one_data = {}
+ one_data['user'] = uid
+ inters = self.inters[uid][:-3]
+ user_ep = user_explicit_preference[uid]
+
+ if self.max_his_len > 0:
+ inters = inters[-self.max_his_len:]
+ one_data['inters'] = ''
+ for item in inters:
+ one_data['inters'] = one_data['inters'] + str(item) + ','
+ one_data['inters'] = one_data['inters'][:-1]
+
+ one_data["explicit_preferences"] = user_ep
+
+ preference_data.append(one_data)
+
+ if self.sample_num > 0:
+ all_idx = range(len(preference_data))
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
+ preference_data = np.array(preference_data)[sample_idx].tolist()
+
+ return preference_data
+
+ def set_prompt(self, prompt_id):
+ self.prompt_id = prompt_id
+
+ def __len__(self):
+ return len(self.preference_data) * self.prompt_sample_num
+
+ def _get_text_data(self, data, prompt):
+
+ instruction = prompt["instruction"].format(**data)
+ response = prompt["response"].format(**data)
+
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = response)
+
+ return input, output
+
+ def __getitem__(self, index):
+
+ idx = index // self.prompt_sample_num
+
+ d = self.preference_data[idx]
+ prompt_id = random.randint(0, len(self.prompts) - 1)
+
+ prompt = self.prompts[prompt_id]
+
+ d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
+ d["explicit_preference"] = d["explicit_preference"].replace('{','').replace('}','')
+
+ input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+ response = prompt["response"].format(**d)
+ output = sft_prompt.format(instruction = prompt["instruction"], response = response)
+ return dict(input_ids = input, labels = output, inters = d['inters'], user = d['user'], item = 'placeholder', task = 'preferenceobtain', users = 'placeholder')
+
+ # item2index & index2item
+ class ItemFeatDataset(BaseDataset):
+ def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
+ super().__init__(args)
+
+ self.task = task.lower()
+ self.prompt_sample_num = prompt_sample_num
+ self.sample_num = sample_num
+
+ self.prompts = all_prompt[self.task]
+
+ self._load_data()
+ self.feat_data = self._process_data()
+
+ def _load_data(self):
+ with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
+ self.item_feat = json.load(f)
+
+ def _process_data(self):
+ feat_data = []
+ for iid in self.item_feat:
+ feat = self.item_feat[iid]
+ feat["item"] = iid
+ feat["title"] = feat["title"].strip().strip(".!?,;:`")
+ feat["title"] = feat["title"].replace('{','').replace('}','')
+ feat["description"] = feat["description"].strip().strip(".!?,;:`")
+ feat["description"] = feat["description"].replace('{','').replace('}','')
+ feat_data.append(feat)
+
+ if self.sample_num > 0:
+ all_idx = range(len(feat_data))
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
+ feat_data = np.array(feat_data)[sample_idx].tolist()
+
+ return feat_data
+
+ def __len__(self):
+ return len(self.feat_data) * self.prompt_sample_num
+
+ def _get_text_data(self, data, prompt):
+ instruction = prompt["instruction"].format(**data)
+ response = prompt["response"].format(**data)
+
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = response)
+
+ return input, output
+
+ def __getitem__(self, index):
+ idx = index // self.prompt_sample_num
+ d = self.feat_data[idx]
+
+ prompt_id = random.randint(0, len(self.prompts) - 1)
+
+ prompt = self.prompts[prompt_id]
+
+ if self.task == 'item2index':
+ instruction = prompt["instruction"].format(**d)
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+ return dict(input_ids = input, labels = output, inters = 'placeholder', user = 'placeholder', item = d['item'], task = self.task, users = 'placeholder')
+ elif self.task == 'index2item':
+ input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+ response = prompt["response"].format(**d)
+ output = sft_prompt.format(instruction = prompt["instruction"], response = response)
+ return dict(input_ids = input, labels = output, inters = 'placeholder', user = 'placeholder', item = d['item'], task = self.task, users = 'placeholder')
+ else:
+ raise NotImplementedError
+
+
+
+ class SeqRecTestDataset(BaseDataset):
+
+ def __init__(self, args, prompt_id=0, sample_num=-1):
+ super().__init__(args)
+
+ self.prompt_id = prompt_id
+ self.sample_num = sample_num
+
+ self.prompt = all_prompt["seqrec"][self.prompt_id]
+
+ # load data
+ self._load_data()
+ self._remap_items()
+
+ self.inter_data = self._process_test_data()
+
+ def _load_data(self):
+
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
+ self.inters = json.load(f)
+ with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+ self.indices = json.load(f)
+
+
+ def _remap_items(self):
+
+ self.remapped_inters = dict()
+ for uid, items in self.inters.items():
+ new_items = ["".join(self.indices[str(i)]) for i in items]
+ self.remapped_inters[uid] = new_items
+
+ def _process_test_data(self):
+
+ inter_data = []
+ for uid in self.remapped_inters:
+ items = self.remapped_inters[uid]
+ one_data = dict()
+ # one_data["user"] = uid
+ one_data["item"] = items[-1]
+ history = items[:-1]
+ if self.max_his_len > 0:
+ history = history[-self.max_his_len:]
+ if self.add_prefix:
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
+ one_data["inters"] = self.his_sep.join(history)
+ inter_data.append(one_data)
+
+ if self.sample_num > 0:
+ all_inter_idx = range(len(inter_data))
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
+
+ inter_data = np.array(inter_data)[sample_idx].tolist()
+
+ return inter_data
+
+ def set_prompt(self, prompt_id):
+ self.prompt_id = prompt_id
+
+ self.prompt = all_prompt["seqrec"][self.prompt_id]
+
+ def __len__(self):
+
+ return len(self.inter_data)
+
+ def _get_text_data(self, data, prompt):
+
+ instruction = prompt["instruction"].format(**data)
+ response = prompt["response"].format(**data)
+
+ input = sft_prompt.format(instruction=instruction, response="")
+
+ return input, response
+
+ def __getitem__(self, index):
+
+ d = self.inter_data[index]
+ input, target = self._get_text_data(d, self.prompt)
+
+ return dict(input_ids=input, labels=target)
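All of the dataset classes in this diff build training pairs the same way: an instruction/response prompt is filled from a data dict, then wrapped in the shared `sft_prompt` template imported from `prompt_finetune`. A minimal, self-contained sketch of that pattern follows; the `sft_prompt` string and the prompt dict below are hypothetical stand-ins for illustration only (the real templates live in `prompt_finetune.py`), and the `<a_i><b_j>` strings mimic the learned item-index tokens.

```python
# Hypothetical stand-in for the template in prompt_finetune.py.
sft_prompt = "### Instruction:\n{instruction}\n\n### Response:\n{response}"

# Hypothetical seqrec-style prompt: predict the next item index from history.
prompt = {
    "instruction": "What is the next item for this user? History: {inters}",
    "response": "{item}",
}
data = {"inters": "1. <a_1><b_2>, 2. <a_3><b_4>", "item": "<a_5><b_6>"}

# Same two-step construction as _get_text_data:
instruction = prompt["instruction"].format(**data)
response = prompt["response"].format(**data)

# Model input ends right after "Response:"; the label appends the target,
# so the label string is exactly the input plus the item index.
model_input = sft_prompt.format(instruction=instruction, response="")
label = sft_prompt.format(instruction=instruction, response=response)
```

This is why the constrained-decoding helper later in the diff searches the generated sequence for the tokenized `"Response:"` separator: everything after it must be a valid item index.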
data_finetune.py ADDED
@@ -0,0 +1,1026 @@
+ import copy
+ import random
+ import argparse
+ import os
+ import torch
+ import torch.nn as nn
+ from torch.utils.data import Dataset
+ from tqdm import tqdm
+ from collections import defaultdict
+ import torch.distributed as dist
+ import logging
+ import re
+ import pdb
+ import json
+ from prompt_finetune import sft_prompt, all_prompt
+ import numpy as np
+
+ class BaseDataset(Dataset):
+
+ def __init__(self, args):
+ super().__init__()
+
+ self.args = args
+ self.dataset = args.dataset
+ self.data_path = os.path.join(args.data_path, self.dataset)
+
+ self.max_his_len = args.max_his_len
+ self.his_sep = args.his_sep
+ self.index_file = args.index_file
+ self.user_index_file = args.user_index_file
+ self.add_prefix = args.add_prefix
+
+ self.new_tokens = None
+ self.allowed_tokens = None
+ self.all_items = None
+
+ def _load_data(self):
+
+ with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+ self.indices = json.load(f)
+
+ def get_new_tokens(self):
+
+ if self.new_tokens is not None:
+ return self.new_tokens
+
+ self.new_tokens = set()
+ for index in self.indices.values():
+ for token in index:
+ self.new_tokens.add(token)
+ self.new_tokens = sorted(list(self.new_tokens))
+
+ return self.new_tokens
+
+ def get_all_items(self):
+
+ if self.all_items is not None:
+ return self.all_items
+
+ self.all_items = set()
+ for index in self.indices.values():
+ self.all_items.add("".join(index))
+
+ return self.all_items
+
+ def get_prefix_allowed_tokens_fn(self, tokenizer):
+
+ if self.allowed_tokens is None:
+ self.allowed_tokens = {}
+ for index in self.indices.values():
+ for i, token in enumerate(index):
+ token_id = tokenizer(token)["input_ids"][1]
+ if i not in self.allowed_tokens.keys():
+ self.allowed_tokens[i] = set()
+ self.allowed_tokens[i].add(token_id)
+ self.allowed_tokens[len(self.allowed_tokens.keys())] = set([tokenizer.eos_token_id])
+ sep = tokenizer("Response:")["input_ids"][1:]
+
+ def prefix_allowed_tokens_fn(batch_id, sentence):
+ sentence = sentence.tolist()
+ reversed_sent = sentence[::-1]
+ for i in range(len(reversed_sent)):
+ if reversed_sent[i:i + len(sep)] == sep[::-1]:
+ # print(list(self.allowed_tokens[i]))
+ return list(self.allowed_tokens[i])
+
+ return prefix_allowed_tokens_fn
+
+ def _process_data(self):
+
+ raise NotImplementedError
+
+ class UserSearchFinetune(BaseDataset):
+ def __init__(self, args, prompt_sample_num = 1, prompt_id = 0, sample_num = -1):
+ super().__init__(args)
+
+ self.prompt_sample_num = prompt_sample_num
+ self.prompt_id = prompt_id
+ self.sample_num = sample_num
+
+ self.prompts = all_prompt["usersearch"]
+
+ self._load_data()
+ self._remap_items()
+ self.search_data = self._process_data()
+
+ def _load_data(self):
+ with open(os.path.join(self.data_path, self.dataset + ".inter.user.json"), 'r') as f:
+ self.user_inters = json.load(f)
+ with open(self.user_index_file, 'r') as f:
+ self.user_indices = json.load(f)
+ with open(self.index_file, 'r') as f:
+ self.indices = json.load(f)
+
+ def _remap_items(self):
+ self.remapped_user_inters = dict()
+ for iid, users in self.user_inters.items():
+ new_users = ["".join(self.user_indices[str(i)]) for i in users]
+ self.remapped_user_inters[iid] = new_users
+
+ def _process_data(self):
+ search_data = []
+ for iid in self.remapped_user_inters.keys():
+ users = self.remapped_user_inters[iid]
+ for i in range(1, len(users)):
+ one_data = {}
+ one_data['item'] = self.indices[iid]
+ one_data['user'] = users[i]
+ history = users[:i]
+
+ if self.max_his_len > 0 and len(history) > self.max_his_len:
+ history = history[-self.max_his_len:]
+
+ one_data['users'] = self.his_sep.join(history)
+
+ # one_data['users'] = ''
+ # for user in history:
+ # one_data['users'] = one_data['users'] + str(user) + ','
+ # one_data['users'] = one_data['users'][:-1]
+
+ search_data.append(one_data)
+
+ if self.sample_num > 0:
+ all_idx = range(len(search_data))
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
+ search_data = np.array(search_data)[sample_idx].tolist()
+
+ return search_data
+
+ def __len__(self):
+ return len(self.search_data) * self.prompt_sample_num
+
+ def __getitem__(self, index):
+ idx = index // self.prompt_sample_num
+ d = self.search_data[idx]
+
+ prompt_id = random.randint(0, len(self.prompts) - 1)
+ prompt = self.prompts[prompt_id]
+
+ instruction = prompt["instruction"].format(**d)
+ response = prompt["response"].format(**d)
+
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = response)
+
+ return dict(input_ids = input, labels = output)
+ # return dict(
+ # input_ids = input,
+ # labels = output,
+ # inters = 'placeholder',
+ # item = d['item'],
+ # users = d['users'],
+ # user = d['user'],
+ # task = 'usersearch'
+ # )
+
+ class UserFeatFinetune(BaseDataset):
+ def __init__(self, args, task = "pref2user", prompt_sample_num = 1, sample_num = -1):
+ super().__init__(args)
+
+ self.task = task.lower()
+ self.prompt_sample_num = prompt_sample_num
+ self.sample_num = sample_num
+
+ self.prompts = all_prompt[self.task]
+
+ self._load_data()
+ self.feat_data = self._process_data()
+
+ def _load_data(self):
+ with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
+ user_feat = json.load(f)
+ self.user_feat = user_feat['user_explicit_preference']
+ with open(self.user_index_file, 'r') as f:
+ self.user_indices = json.load(f)
+
+ def _process_data(self):
+ feat_data = []
+ for uid in self.user_feat:
+ one_data = {}
+ one_data['user'] = self.user_indices[uid]
+
+ preference = " ".join(self.user_feat[uid])
+ preference = preference.strip().strip(".!?,;:`")
+ preference = preference.replace('{','').replace('}','')
+ one_data['preference'] = preference
+
+ feat_data.append(one_data)
+
+ if self.sample_num > 0:
+ all_idx = range(len(feat_data))
+ sample_idx = np.random.choice(all_idx, self.sample_num, replace = False)
+ feat_data = np.array(feat_data)[sample_idx].tolist()
+
+ return feat_data
+
+ def __len__(self):
+ return len(self.feat_data) * self.prompt_sample_num
+
+ def __getitem__(self, index):
+ idx = index // self.prompt_sample_num
+ d = self.feat_data[idx]
+ prompt_id = random.randint(0, len(self.prompts) - 1)
+ prompt = self.prompts[prompt_id]
+
+ instruction = prompt["instruction"].format(**d)
+ response = prompt["response"].format(**d)
+
+ input = sft_prompt.format(instruction = instruction, response = "")
+ output = sft_prompt.format(instruction = instruction, response = response)
+
+ return dict(input_ids = input, labels = output)
+
+ # if self.task == 'pref2user':
+ # instruction = prompt['instruction'].format(preference = d['preference'])
+ # input = sft_prompt.format(instruction = instruction, response = "")
+ # output = sft_prompt.format(instruction = instruction, response = prompt["response"])
+ # return dict(
+ # input_ids = input,
+ # labels = output,
+ # inters = 'placeholder',
+ # item = 'placeholder',
+ # users = 'placeholder',
+ # user = d['user'],
+ # task = self.task
+ # )
+ # elif self.task == 'user2pref':
+ # input = sft_prompt.format(instruction = prompt["instruction"], response = "")
+ # response = prompt["response"].format(preference = d['preference'])
+ # output = sft_prompt.format(instruction = prompt["instruction"], response = response)
+ # return dict(
+ # input_ids = input,
+ # labels = output,
+ # inters = 'placeholder',
+ # item = 'placeholder',
+ # users = 'placeholder',
+ # user = d['user'],
+ # task = self.task
+ # )
+ # else:
+ # raise NotImplementedError
+
+ class SeqRecFinetune(BaseDataset):
264
+
265
+ def __init__(self, args, mode="train",
266
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
267
+ super().__init__(args)
268
+
269
+ self.mode = mode
270
+ self.prompt_sample_num = prompt_sample_num
271
+ self.prompt_id = prompt_id
272
+ self.sample_num = sample_num
273
+
274
+ self.prompts = all_prompt["seqrec"]
275
+
276
+ # load data
277
+ self._load_data()
278
+ self._remap_items()
279
+
280
+ # load data
281
+ if self.mode == 'train':
282
+ self.inter_data = self._process_train_data()
283
+ elif self.mode == 'valid':
284
+ self.sample_valid = args.sample_valid
285
+ self.valid_prompt_id = args.valid_prompt_id
286
+ self.inter_data = self._process_valid_data()
287
+ self._construct_valid_text()
288
+ elif self.mode == 'test':
289
+ self.inter_data = self._process_test_data()
290
+ else:
291
+ raise NotImplementedError
292
+
293
+ def _load_data(self):
294
+
295
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
296
+ self.inters = json.load(f)
297
+ with open(self.index_file, 'r') as f:
298
+ self.indices = json.load(f)
299
+ with open(self.user_index_file, 'r') as f:
300
+ self.user_indices = json.load(f)
301
+
302
+ def _remap_items(self):
303
+
304
+ self.remapped_inters = dict()
305
+ for uid, items in self.inters.items():
306
+ new_items = ["".join(self.indices[str(i)]) for i in items]
307
+ self.remapped_inters[uid] = new_items
308
+
309
+ def _process_train_data(self):
310
+
311
+ inter_data = []
312
+ for uid in self.remapped_inters:
313
+ items = self.remapped_inters[uid][:-2]
314
+ for i in range(1, len(items)):
315
+ one_data = dict()
316
+ one_data["user"] = self.user_indices[uid]
317
+ one_data["item"] = items[i]
318
+ history = items[:i]
319
+ if self.max_his_len > 0:
320
+ history = history[-self.max_his_len:]
321
+ if self.add_prefix:
322
+ history = [str(k+1) + ". " + item_idx for k, item_idx in enumerate(history)]
323
+ one_data["inters"] = self.his_sep.join(history)
324
+ inter_data.append(one_data)
325
+
326
+ return inter_data
327
+
328
+ def _process_valid_data(self):
329
+
330
+ inter_data = []
331
+ for uid in self.remapped_inters:
332
+ items = self.remapped_inters[uid]
333
+ one_data = dict()
334
+ # one_data["user"] = uid
335
+ one_data["user"] = self.user_indices[uid]
336
+ one_data["item"] = items[-2]
337
+ history = items[:-2]
338
+ if self.max_his_len > 0:
339
+ history = history[-self.max_his_len:]
340
+ if self.add_prefix:
341
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
342
+ one_data["inters"] = self.his_sep.join(history)
343
+ inter_data.append(one_data)
344
+
345
+ return inter_data
346
+
347
+ def _process_test_data(self):
348
+
349
+ inter_data = []
350
+ for uid in self.remapped_inters:
351
+ items = self.remapped_inters[uid]
352
+ one_data = dict()
353
+ # one_data["user"] = uid
354
+ one_data["user"] = self.user_indices[uid]
355
+ one_data["item"] = items[-1]
356
+ history = items[:-1]
357
+ if self.max_his_len > 0:
358
+ history = history[-self.max_his_len:]
359
+ if self.add_prefix:
360
+ history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
361
+ one_data["inters"] = self.his_sep.join(history)
362
+ inter_data.append(one_data)
363
+
364
+ if self.sample_num > 0:
365
+ all_inter_idx = range(len(inter_data))
366
+ sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
367
+ inter_data = np.array(inter_data)[sample_idx].tolist()
368
+
369
+ return inter_data
370
+
371
+ def set_prompt(self, prompt_id):
372
+
373
+ self.prompt_id = prompt_id
374
+
375
+ def __len__(self):
376
+ if self.mode == 'train':
377
+ return len(self.inter_data) * self.prompt_sample_num
378
+ elif self.mode == 'valid':
379
+ return len(self.valid_text_data)
380
+ elif self.mode == 'test':
381
+ return len(self.inter_data)
382
+ else:
383
+ raise NotImplementedError
384
+
385
+ def _construct_valid_text(self):
386
+ self.valid_text_data = []
387
+ if self.sample_valid:
388
+ all_prompt_ids = range(len(self.prompts))
389
+ for i in range(len(self.inter_data)):
390
+ d = self.inter_data[i]
391
+ prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
392
+ for prompt_id in prompt_ids:
393
+ prompt = self.prompts[prompt_id]
394
+ input, output = self._get_text_data(d, prompt)
395
+ self.valid_text_data.append({"input_ids": input, "labels": output})
396
+ else:
397
+ self.prompt_sample_num = 1
398
+ prompt = self.prompts[self.valid_prompt_id]
399
+ for i in range(len(self.inter_data)):
400
+ d = self.inter_data[i]
401
+ input, output = self._get_text_data(d, prompt)
402
+ self.valid_text_data.append({"input_ids": input, "labels": output})
403
+
404
+ def _get_text_data(self, data, prompt):
405
+
406
+ instruction = prompt["instruction"].format(**data)
407
+ response = prompt["response"].format(**data)
408
+
409
+ input = sft_prompt.format(instruction = instruction, response = "")
410
+ output = sft_prompt.format(instruction = instruction, response = response)
411
+
412
+ if self.mode == 'test':
413
+ return input, response
414
+
415
+ return input, output
416
+
417
+ def __getitem__(self, index):
418
+
419
+ if self.mode == 'valid':
420
+ return self.valid_text_data[index]
421
+
422
+ idx = index // self.prompt_sample_num
423
+ d = self.inter_data[idx]
424
+ # print(index, idx)
425
+
426
+ if self.mode == 'train':
427
+ prompt_id = random.randint(0, len(self.prompts) - 1)
428
+ elif self.mode == 'test':
429
+ prompt_id = self.prompt_id
430
+
431
+ prompt = self.prompts[prompt_id]
432
+
433
+ input, output = self._get_text_data(d, prompt)
434
+
435
+ # print({"input": input, "output": output})
436
+
437
+ return dict(input_ids=input, labels=output)
438
+
439
+
440
+ class FusionSeqRecFinetune(BaseDataset):
441
+
442
+ def __init__(self, args, mode="train",
443
+ prompt_sample_num=1, prompt_id=0, sample_num=-1):
444
+ super().__init__(args)
445
+
446
+ self.mode = mode
447
+ self.prompt_sample_num = prompt_sample_num
448
+ self.prompt_id = prompt_id
449
+ self.sample_num = sample_num
450
+
451
+ self.prompts = all_prompt["fusionseqrec"]
452
+
453
+ # load data
454
+ self._load_data()
455
+ # self._remap_items()
456
+
457
+ # load data
458
+ if self.mode == 'train':
459
+ self.inter_data = self._process_train_data()
460
+ elif self.mode == 'valid':
461
+ self.sample_valid = args.sample_valid
462
+ self.valid_prompt_id = args.valid_prompt_id
463
+ self.inter_data = self._process_valid_data()
464
+ self._construct_valid_text()
465
+ elif self.mode == 'test':
466
+ self.inter_data = self._process_test_data()
467
+ else:
468
+ raise NotImplementedError
469
+
470
+
471
+ def _load_data(self):
472
+
473
+ with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
474
+ self.inters = json.load(f)
475
+ with open(self.index_file, 'r') as f:
476
+ self.indices = json.load(f)
477
+ with open(self.user_index_file, 'r') as f:
478
+ self.user_indices = json.load(f)
479
+ # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
480
+ # self.indices = json.load(f)
481
+ with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
482
+ self.item_feat = json.load(f)
483
+
484
+ def _process_train_data(self):
485
+
486
+ inter_data = []
487
+ for uid in self.inters:
488
+ items = self.inters[uid][:-2]
489
+ for i in range(1, len(items)):
490
+ one_data = dict()
491
+ # one_data["user"] = uid
492
+ one_data["user"] = self.user_indices[uid]
493
+ one_data["item"] = "".join(self.indices[str(items[i])])
494
+ one_data["title"] = self.item_feat[str(items[i])]["title"].strip().strip(".!?,;:`")
495
+ one_data["description"] = self.item_feat[str(items[i])]["description"]
496
+ history = items[:i]
497
+ if self.max_his_len > 0:
498
+ history = history[-self.max_his_len:]
499
+ inters = ["".join(self.indices[str(j)]) for j in history]
500
+ inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
501
+
502
+
503
+ if self.add_prefix:
504
+ inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
505
+                 inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
+ 
+             one_data["inters"] = self.his_sep.join(inters)
+             one_data["inter_titles"] = self.his_sep.join(inter_titles)
+             inter_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_inter_idx = range(len(inter_data))
+             sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
+             inter_data = np.array(inter_data)[sample_idx].tolist()
+ 
+         return inter_data
+ 
+     def _process_valid_data(self):
+ 
+         inter_data = []
+         for uid in self.inters:
+             items = self.inters[uid]
+             one_data = dict()
+             one_data["item"] = "".join(self.indices[str(items[-2])])
+             one_data["title"] = self.item_feat[str(items[-2])]["title"].strip().strip(".!?,;:`")
+             one_data["description"] = self.item_feat[str(items[-2])]["description"]
+ 
+             history = items[:-2]
+             if self.max_his_len > 0:
+                 history = history[-self.max_his_len:]
+             inters = ["".join(self.indices[str(j)]) for j in history]
+             inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
+ 
+             if self.add_prefix:
+                 inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
+                 inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
+ 
+             one_data["inters"] = self.his_sep.join(inters)
+             one_data["inter_titles"] = self.his_sep.join(inter_titles)
+             inter_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_inter_idx = range(len(inter_data))
+             sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
+             inter_data = np.array(inter_data)[sample_idx].tolist()
+ 
+         return inter_data
+ 
+     def _process_test_data(self):
+ 
+         inter_data = []
+         for uid in self.inters:
+             items = self.inters[uid]
+             one_data = dict()
+             one_data["item"] = "".join(self.indices[str(items[-1])])
+             one_data["title"] = self.item_feat[str(items[-1])]["title"].strip().strip(".!?,;:`")
+             one_data["description"] = self.item_feat[str(items[-1])]["description"]
+ 
+             history = items[:-1]
+             if self.max_his_len > 0:
+                 history = history[-self.max_his_len:]
+             inters = ["".join(self.indices[str(j)]) for j in history]
+             inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history]
+ 
+             if self.add_prefix:
+                 inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
+                 inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]
+ 
+             one_data["inters"] = self.his_sep.join(inters)
+             one_data["inter_titles"] = self.his_sep.join(inter_titles)
+             inter_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_inter_idx = range(len(inter_data))
+             sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
+             inter_data = np.array(inter_data)[sample_idx].tolist()
+ 
+         return inter_data
+ 
+     def set_prompt(self, prompt_id):
+         self.prompt_id = prompt_id
+ 
+     def __len__(self):
+         if self.mode == 'train':
+             return len(self.inter_data) * self.prompt_sample_num
+         elif self.mode == 'valid':
+             return len(self.valid_text_data)
+         elif self.mode == 'test':
+             return len(self.inter_data)
+         else:
+             raise NotImplementedError
+ 
+     def _construct_valid_text(self):
+         self.valid_text_data = []
+         if self.sample_valid:
+             all_prompt_ids = range(len(self.prompts))
+             for i in range(len(self.inter_data)):
+                 d = self.inter_data[i]
+                 prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
+                 for prompt_id in prompt_ids:
+                     prompt = self.prompts[prompt_id]
+                     input, output = self._get_text_data(d, prompt)
+                     self.valid_text_data.append({"input_ids": input, "labels": output})
+         else:
+             self.prompt_sample_num = 1
+             prompt = self.prompts[self.valid_prompt_id]
+             for i in range(len(self.inter_data)):
+                 d = self.inter_data[i]
+                 input, output = self._get_text_data(d, prompt)
+                 self.valid_text_data.append({"input_ids": input, "labels": output})
+ 
+     def _get_text_data(self, data, prompt):
+ 
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+ 
+         input = sft_prompt.format(instruction=instruction, response="")
+         output = sft_prompt.format(instruction=instruction, response=response)
+ 
+         if self.mode == 'test':
+             return input, response
+ 
+         return input, output
+ 
+     def __getitem__(self, index):
+ 
+         if self.mode == 'valid':
+             return self.valid_text_data[index]
+ 
+         idx = index // self.prompt_sample_num
+         d = self.inter_data[idx]
+ 
+         if self.mode == 'train':
+             prompt_id = random.randint(0, len(self.prompts) - 1)
+         elif self.mode == 'test':
+             prompt_id = self.prompt_id
+ 
+         prompt = self.prompts[prompt_id]
+ 
+         input, output = self._get_text_data(d, prompt)
+ 
+         return dict(input_ids=input, labels=output)
+ 
+ 
+ class ItemFeatFinetune(BaseDataset):
+ 
+     def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
+         super().__init__(args)
+ 
+         self.task = task.lower()
+         self.prompt_sample_num = prompt_sample_num
+         self.sample_num = sample_num
+ 
+         self.prompts = all_prompt[self.task]
+ 
+         # load data
+         self._load_data()
+         self.feat_data = self._process_data()
+ 
+     def _load_data(self):
+ 
+         # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+         #     self.indices = json.load(f)
+         with open(self.index_file, 'r') as f:
+             self.indices = json.load(f)
+         with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
+             self.item_feat = json.load(f)
+ 
+     def _process_data(self):
+ 
+         feat_data = []
+         for iid in self.item_feat:
+             feat = self.item_feat[iid]
+             index = "".join(self.indices[iid])
+             feat["item"] = index
+             feat["title"] = feat["title"].strip().strip(".!?,;:`")
+             feat_data.append(feat)
+ 
+         if self.sample_num > 0:
+             all_idx = range(len(feat_data))
+             sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
+             feat_data = np.array(feat_data)[sample_idx].tolist()
+ 
+         return feat_data
+ 
+     def __len__(self):
+         return len(self.feat_data) * self.prompt_sample_num
+ 
+     def _get_text_data(self, data, prompt):
+ 
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+ 
+         input = sft_prompt.format(instruction=instruction, response="")
+         output = sft_prompt.format(instruction=instruction, response=response)
+ 
+         return input, output
+ 
+     def __getitem__(self, index):
+ 
+         idx = index // self.prompt_sample_num
+         d = self.feat_data[idx]
+ 
+         prompt_id = random.randint(0, len(self.prompts) - 1)
+         prompt = self.prompts[prompt_id]
+ 
+         input, output = self._get_text_data(d, prompt)
+ 
+         return dict(input_ids=input, labels=output)
+ 
+ 
+ class ItemSearchFinetune(BaseDataset):
+ 
+     def __init__(self, args, mode="train",
+                  prompt_sample_num=1, prompt_id=0, sample_num=-1):
+         super().__init__(args)
+ 
+         self.mode = mode
+         self.prompt_sample_num = prompt_sample_num
+         self.prompt_id = prompt_id
+         self.sample_num = sample_num
+ 
+         self.prompts = all_prompt["itemsearch"]
+ 
+         # load data
+         self._load_data()
+         self.search_data = self._process_data()
+ 
+     def _load_data(self):
+ 
+         # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+         #     self.indices = json.load(f)
+         with open(self.index_file, 'r') as f:
+             self.indices = json.load(f)
+         with open(self.user_index_file, 'r') as f:
+             self.user_indices = json.load(f)
+         with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
+             self.user_info = json.load(f)
+ 
+     def _process_data(self):
+ 
+         search_data = []
+         user_explicit_preference = self.user_info["user_explicit_preference"]
+         user_vague_intention = self.user_info["user_vague_intention"]
+         if self.mode == 'train':
+             user_vague_intention = user_vague_intention["train"]
+         elif self.mode == 'test':
+             user_vague_intention = user_vague_intention["test"]
+         else:
+             raise NotImplementedError
+ 
+         for uid in user_explicit_preference.keys():
+             one_data = {}
+             user_ep = user_explicit_preference[uid]
+             user_vi = user_vague_intention[uid]["querys"]
+             one_data["explicit_preferences"] = user_ep
+             one_data["user_related_intention"] = user_vi[0]
+             one_data["item_related_intention"] = user_vi[1]
+             one_data["user"] = self.user_indices[uid]
+ 
+             iid = user_vague_intention[uid]["item"]
+             inters = user_vague_intention[uid]["inters"]
+ 
+             index = "".join(self.indices[str(iid)])
+             one_data["item"] = index
+ 
+             if self.max_his_len > 0:
+                 inters = inters[-self.max_his_len:]
+             inters = ["".join(self.indices[str(i)]) for i in inters]
+             if self.add_prefix:
+                 inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
+ 
+             one_data["inters"] = self.his_sep.join(inters)
+ 
+             search_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_idx = range(len(search_data))
+             sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
+             search_data = np.array(search_data)[sample_idx].tolist()
+ 
+         return search_data
+ 
+     def set_prompt(self, prompt_id):
+         self.prompt_id = prompt_id
+ 
+     def __len__(self):
+         if self.mode == 'train':
+             return len(self.search_data) * self.prompt_sample_num
+         return len(self.search_data)
+ 
+     def _get_text_data(self, data, prompt):
+ 
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+ 
+         input = sft_prompt.format(instruction=instruction, response="")
+         output = sft_prompt.format(instruction=instruction, response=response)
+ 
+         if self.mode == 'test':
+             return input, response
+ 
+         return input, output
+ 
+     def __getitem__(self, index):
+ 
+         idx = index // self.prompt_sample_num
+ 
+         d = self.search_data[idx]
+         if self.mode == 'train':
+             prompt_id = random.randint(0, len(self.prompts) - 1)
+         elif self.mode == 'test':
+             prompt_id = self.prompt_id
+ 
+         prompt = self.prompts[prompt_id]
+ 
+         d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
+         all_querys = [d["user_related_intention"], d["item_related_intention"]]
+         d["query"] = random.choice(all_querys)
+ 
+         input, output = self._get_text_data(d, prompt)
+ 
+         return dict(input_ids=input, labels=output)
+ 
+ 
+ class PreferenceObtainFinetune(BaseDataset):
+ 
+     def __init__(self, args, prompt_sample_num=1, sample_num=-1):
+         super().__init__(args)
+ 
+         self.prompt_sample_num = prompt_sample_num
+         self.sample_num = sample_num
+ 
+         self.prompts = all_prompt["preferenceobtain"]
+ 
+         # load data
+         self._load_data()
+         self._remap_items()
+ 
+         self.preference_data = self._process_data()
+ 
+     def _load_data(self):
+ 
+         with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
+             self.user_info = json.load(f)
+         with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
+             self.inters = json.load(f)
+         # with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+         #     self.indices = json.load(f)
+         with open(self.index_file, 'r') as f:
+             self.indices = json.load(f)
+         with open(self.user_index_file, 'r') as f:
+             self.user_indices = json.load(f)
+ 
+     def _remap_items(self):
+ 
+         self.remapped_inters = dict()
+         for uid, items in self.inters.items():
+             new_items = ["".join(self.indices[str(i)]) for i in items]
+             self.remapped_inters[uid] = new_items
+ 
+     def _process_data(self):
+ 
+         preference_data = []
+         user_explicit_preference = self.user_info["user_explicit_preference"]
+ 
+         for uid in user_explicit_preference.keys():
+             one_data = {}
+             one_data["user"] = self.user_indices[uid]
+             inters = self.remapped_inters[uid][:-3]
+             user_ep = user_explicit_preference[uid]
+ 
+             if self.max_his_len > 0:
+                 inters = inters[-self.max_his_len:]
+             if self.add_prefix:
+                 inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
+ 
+             one_data["explicit_preferences"] = user_ep
+             one_data["inters"] = self.his_sep.join(inters)
+ 
+             preference_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_idx = range(len(preference_data))
+             sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
+             preference_data = np.array(preference_data)[sample_idx].tolist()
+ 
+         return preference_data
+ 
+     def set_prompt(self, prompt_id):
+         self.prompt_id = prompt_id
+ 
+     def __len__(self):
+         return len(self.preference_data) * self.prompt_sample_num
+ 
+     def _get_text_data(self, data, prompt):
+ 
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+ 
+         input = sft_prompt.format(instruction=instruction, response="")
+         output = sft_prompt.format(instruction=instruction, response=response)
+ 
+         return input, output
+ 
+     def __getitem__(self, index):
+ 
+         idx = index // self.prompt_sample_num
+ 
+         d = self.preference_data[idx]
+         prompt_id = random.randint(0, len(self.prompts) - 1)
+ 
+         prompt = self.prompts[prompt_id]
+ 
+         d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
+ 
+         input, output = self._get_text_data(d, prompt)
+ 
+         return dict(input_ids=input, labels=output)
+ 
+ 
+ class SeqRecTestDataset(BaseDataset):
+ 
+     def __init__(self, args, prompt_id=0, sample_num=-1):
+         super().__init__(args)
+ 
+         self.prompt_id = prompt_id
+         self.sample_num = sample_num
+ 
+         self.prompt = all_prompt["seqrec"][self.prompt_id]
+ 
+         # load data
+         self._load_data()
+         self._remap_items()
+ 
+         self.inter_data = self._process_test_data()
+ 
+     def _load_data(self):
+ 
+         with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
+             self.inters = json.load(f)
+         with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
+             self.indices = json.load(f)
+ 
+     def _remap_items(self):
+ 
+         self.remapped_inters = dict()
+         for uid, items in self.inters.items():
+             new_items = ["".join(self.indices[str(i)]) for i in items]
+             self.remapped_inters[uid] = new_items
+ 
+     def _process_test_data(self):
+ 
+         inter_data = []
+         for uid in self.remapped_inters:
+             items = self.remapped_inters[uid]
+             one_data = dict()
+             # one_data["user"] = uid
+             one_data["item"] = items[-1]
+             history = items[:-1]
+             if self.max_his_len > 0:
+                 history = history[-self.max_his_len:]
+             if self.add_prefix:
+                 history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
+             one_data["inters"] = self.his_sep.join(history)
+             inter_data.append(one_data)
+ 
+         if self.sample_num > 0:
+             all_inter_idx = range(len(inter_data))
+             sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
+             inter_data = np.array(inter_data)[sample_idx].tolist()
+ 
+         return inter_data
+ 
+     def set_prompt(self, prompt_id):
+         self.prompt_id = prompt_id
+         self.prompt = all_prompt["seqrec"][self.prompt_id]
+ 
+     def __len__(self):
+         return len(self.inter_data)
+ 
+     def _get_text_data(self, data, prompt):
+ 
+         instruction = prompt["instruction"].format(**data)
+         response = prompt["response"].format(**data)
+ 
+         input = sft_prompt.format(instruction=instruction, response="")
+ 
+         return input, response
+ 
+     def __getitem__(self, index):
+ 
+         d = self.inter_data[index]
+         input, target = self._get_text_data(d, self.prompt)
+ 
+         return dict(input_ids=input, labels=target)
data_process/amazon18_data_process.py ADDED
@@ -0,0 +1,299 @@
+ import argparse
+ import collections
+ import gzip
+ import html
+ import json
+ import os
+ import random
+ import re
+ import torch
+ from tqdm import tqdm
+ import numpy as np
+ from utils import check_path, clean_text, amazon18_dataset2fullname, write_json_file, write_remap_index
+ 
+ def load_ratings(file):
+     users, items, inters = set(), set(), set()
+     with open(file, 'r') as fp:
+         for line in tqdm(fp, desc='Load ratings'):
+             try:
+                 item, user, rating, time = line.strip().split(',')
+                 users.add(user)
+                 items.add(item)
+                 inters.add((user, item, float(rating), int(time)))
+             except ValueError:
+                 print(line)
+     return users, items, inters
+ 
+ 
+ def load_meta_items(file):
+     items = {}
+     with gzip.open(file, "r") as fp:
+         for line in tqdm(fp, desc="Load metas"):
+             data = json.loads(line)
+             item = data["asin"]
+             title = clean_text(data["title"])
+ 
+             descriptions = data["description"]
+             descriptions = clean_text(descriptions)
+ 
+             brand = data["brand"].replace("by\n", "").strip()
+ 
+             categories = data["category"]
+             new_categories = []
+             for category in categories:
+                 if "</span>" in category:
+                     break
+                 new_categories.append(category.strip())
+             categories = ",".join(new_categories).strip()
+ 
+             items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
+             # print(items[item])
+     return items
+ 
+ 
+ def load_review_data(args, user2id, item2id):
+ 
+     dataset_full_name = amazon18_dataset2fullname[args.dataset]
+     review_file_path = os.path.join(args.input_path, 'Review', dataset_full_name + '.json.gz')
+ 
+     reviews = {}
+ 
+     with gzip.open(review_file_path, "r") as fp:
+ 
+         for line in tqdm(fp, desc='Load reviews'):
+             inter = json.loads(line)
+             try:
+                 user = inter['reviewerID']
+                 item = inter['asin']
+                 if user in user2id and item in item2id:
+                     uid = user2id[user]
+                     iid = item2id[item]
+                 else:
+                     continue
+                 if 'reviewText' in inter:
+                     review = clean_text(inter['reviewText'])
+                 else:
+                     review = ''
+                 if 'summary' in inter:
+                     summary = clean_text(inter['summary'])
+                 else:
+                     summary = ''
+                 reviews[str((uid, iid))] = {"review": review, "summary": summary}
+ 
+             except ValueError:
+                 print(line)
+ 
+     return reviews
+ 
+ 
+ def get_user2count(inters):
+     user2count = collections.defaultdict(int)
+     for unit in inters:
+         user2count[unit[0]] += 1
+     return user2count
+ 
+ 
+ def get_item2count(inters):
+     item2count = collections.defaultdict(int)
+     for unit in inters:
+         item2count[unit[1]] += 1
+     return item2count
+ 
+ 
+ def generate_candidates(unit2count, threshold):
+     cans = set()
+     for unit, count in unit2count.items():
+         if count >= threshold:
+             cans.add(unit)
+     return cans, len(unit2count) - len(cans)
+ 
+ 
+ def filter_inters(inters, can_items=None,
+                   user_k_core_threshold=0, item_k_core_threshold=0):
+     new_inters = []
+ 
+     # filter by meta items
+     if can_items:
+         print('\nFiltering by meta items: ')
+         for unit in inters:
+             if unit[1] in can_items.keys():
+                 new_inters.append(unit)
+         inters, new_inters = new_inters, []
+         print('    The number of inters: ', len(inters))
+ 
+     # filter by k-core
+     if user_k_core_threshold or item_k_core_threshold:
+         print('\nFiltering by k-core:')
+         idx = 0
+         user2count = get_user2count(inters)
+         item2count = get_item2count(inters)
+ 
+         while True:
+             new_user2count = collections.defaultdict(int)
+             new_item2count = collections.defaultdict(int)
+             users, n_filtered_users = generate_candidates(  # users is set
+                 user2count, user_k_core_threshold)
+             items, n_filtered_items = generate_candidates(
+                 item2count, item_k_core_threshold)
+             if n_filtered_users == 0 and n_filtered_items == 0:
+                 break
+             for unit in inters:
+                 if unit[0] in users and unit[1] in items:
+                     new_inters.append(unit)
+                     new_user2count[unit[0]] += 1
+                     new_item2count[unit[1]] += 1
+             idx += 1
+             inters, new_inters = new_inters, []
+             user2count, item2count = new_user2count, new_item2count
+             print('    Epoch %d The number of inters: %d, users: %d, items: %d'
+                   % (idx, len(inters), len(user2count), len(item2count)))
+     return inters
+ 
+ 
+ def make_inters_in_order(inters):
+     user2inters, new_inters = collections.defaultdict(list), list()
+     for inter in inters:
+         user, item, rating, timestamp = inter
+         user2inters[user].append((user, item, rating, timestamp))
+     for user in user2inters:
+         user_inters = user2inters[user]
+         user_inters.sort(key=lambda d: d[3])
+         interacted_item = set()
+         for inter in user_inters:
+             if inter[1] in interacted_item:  # filter duplicate interactions
+                 continue
+             interacted_item.add(inter[1])
+             new_inters.append(inter)
+     return new_inters
+ 
+ 
+ def preprocess_rating(args):
+     dataset_full_name = amazon18_dataset2fullname[args.dataset]
+ 
+     print('Process rating data: ')
+     print(' Dataset: ', args.dataset)
+ 
+     # load ratings
+     rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
+     rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
+ 
+     # load item IDs with meta data
+     meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
+     meta_items = load_meta_items(meta_file_path)
+ 
+     # 1. Filter items w/o meta data;
+     # 2. K-core filtering;
+     print('The number of raw inters: ', len(rating_inters))
+ 
+     rating_inters = make_inters_in_order(rating_inters)
+ 
+     rating_inters = filter_inters(rating_inters, can_items=meta_items,
+                                   user_k_core_threshold=args.user_k,
+                                   item_k_core_threshold=args.item_k)
+ 
+     # sort interactions chronologically for each user
+     rating_inters = make_inters_in_order(rating_inters)
+     print('\n')
+ 
+     # return: list of (user_ID, item_ID, rating, timestamp)
+     return rating_inters, meta_items
+ 
+ def convert_inters2dict(inters):
+     user2items = collections.defaultdict(list)
+     user2index, item2index = dict(), dict()
+     for inter in inters:
+         user, item, rating, timestamp = inter
+         if user not in user2index:
+             user2index[user] = len(user2index)
+         if item not in item2index:
+             item2index[item] = len(item2index)
+         user2items[user2index[user]].append(item2index[item])
+     return user2items, user2index, item2index
+ 
+ def generate_data(args, rating_inters):
+     print('Split dataset: ')
+     print(' Dataset: ', args.dataset)
+ 
+     # generate train/valid/test splits
+     user2items, user2index, item2index = convert_inters2dict(rating_inters)
+     train_inters, valid_inters, test_inters = dict(), dict(), dict()
+     for u_index in range(len(user2index)):
+         inters = user2items[u_index]
+         # leave one out
+         train_inters[u_index] = [str(i_index) for i_index in inters[:-2]]
+         valid_inters[u_index] = [str(inters[-2])]
+         test_inters[u_index] = [str(inters[-1])]
+         assert len(user2items[u_index]) == len(train_inters[u_index]) + \
+                len(valid_inters[u_index]) + len(test_inters[u_index])
+     return user2items, train_inters, valid_inters, test_inters, user2index, item2index
+ 
+ def convert_to_atomic_files(args, train_data, valid_data, test_data):
+     print('Convert dataset: ')
+     print(' Dataset: ', args.dataset)
+     uid_list = list(train_data.keys())
+     uid_list.sort(key=lambda t: int(t))
+ 
+     with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.train.inter'), 'w') as file:
+         file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
+         for uid in uid_list:
+             item_seq = train_data[uid]
+             seq_len = len(item_seq)
+             for target_idx in range(1, seq_len):
+                 target_item = item_seq[-target_idx]
+                 seq = item_seq[:-target_idx][-50:]
+                 file.write(f'{uid}\t{" ".join(seq)}\t{target_item}\n')
+ 
+     with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.valid.inter'), 'w') as file:
+         file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
+         for uid in uid_list:
+             item_seq = train_data[uid][-50:]
+             target_item = valid_data[uid][0]
+             file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
+ 
+     with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.test.inter'), 'w') as file:
+         file.write('user_id:token\titem_id_list:token_seq\titem_id:token\n')
+         for uid in uid_list:
+             item_seq = (train_data[uid] + valid_data[uid])[-50:]
+             target_item = test_data[uid][0]
+             file.write(f'{uid}\t{" ".join(item_seq)}\t{target_item}\n')
+ 
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
+     parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
+     parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
+     parser.add_argument('--input_path', type=str, default='')
+     parser.add_argument('--output_path', type=str, default='')
+     return parser.parse_args()
+ 
+ 
+ if __name__ == '__main__':
+     args = parse_args()
+ 
+     # load interactions from raw rating file
+     rating_inters, meta_items = preprocess_rating(args)
+ 
+     # split train/valid/test
+     all_inters, train_inters, valid_inters, test_inters, user2index, item2index = generate_data(args, rating_inters)
+ 
+     check_path(os.path.join(args.output_path, args.dataset))
+ 
+     write_json_file(all_inters, os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter.json'))
+     convert_to_atomic_files(args, train_inters, valid_inters, test_inters)
+ 
+     item2feature = collections.defaultdict(dict)
+     for item, item_id in item2index.items():
+         item2feature[item_id] = meta_items[item]
+ 
+     # reviews = load_review_data(args, user2index, item2index)
+ 
+     print("user:", len(user2index))
+     print("item:", len(item2index))
+ 
+     write_json_file(item2feature, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item.json'))
+     # write_json_file(reviews, os.path.join(args.output_path, args.dataset, f'{args.dataset}.review.json'))
+ 
+     write_remap_index(user2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.user2id'))
+     write_remap_index(item2index, os.path.join(args.output_path, args.dataset, f'{args.dataset}.item2id'))
data_process/amazon18_recbole_data_process.py ADDED
@@ -0,0 +1,226 @@
+ import argparse
+ import collections
+ import gzip
+ import html
+ import json
+ import os
+ import random
+ import re
+ import torch
+ from tqdm import tqdm
+ import numpy as np
+ from utils import check_path, clean_text, amazon18_dataset2fullname, write_json_file, write_remap_index
+ 
+ def load_ratings(file):
+     users, items, inters = set(), set(), set()
+     with open(file, 'r') as fp:
+         for line in tqdm(fp, desc='Load ratings'):
+             try:
+                 item, user, rating, time = line.strip().split(',')
+                 users.add(user)
+                 items.add(item)
+                 inters.add((user, item, float(rating), int(time)))
+             except ValueError:
+                 print(line)
+     return users, items, inters
+ 
+ 
+ def load_meta_items(file):
+     items = {}
+     # re_tag = re.compile('</?\w+[^>]*>')
+     with gzip.open(file, "r") as fp:
+         for line in tqdm(fp, desc="Load metas"):
+             data = json.loads(line)
+             item = data["asin"]
+             title = clean_text(data["title"])
+ 
+             descriptions = data["description"]
+             descriptions = clean_text(descriptions)
+             # new_descriptions = []
+             # for description in descriptions:
+             #     description = re.sub(re_tag, '', description)
+             #     new_descriptions.append(description.strip())
+             # descriptions = " ".join(new_descriptions).strip()
+ 
+             brand = data["brand"].replace("by\n", "").strip()
+ 
+             categories = data["category"]
+             new_categories = []
+             for category in categories:
+                 if "</span>" in category:
+                     break
+                 new_categories.append(category.strip())
+             categories = ",".join(new_categories[1:]).strip()
+ 
+             items[item] = {"title": title, "description": descriptions, "brand": brand, "categories": categories}
+             # print(items[item])
+     return items
+ 
+ 
+ def get_user2count(inters):
+     user2count = collections.defaultdict(int)
+     for unit in inters:
+         user2count[unit[0]] += 1
+     return user2count
+ 
+ 
+ def get_item2count(inters):
+     item2count = collections.defaultdict(int)
+     for unit in inters:
+         item2count[unit[1]] += 1
+     return item2count
+ 
+ 
+ def generate_candidates(unit2count, threshold):
+     cans = set()
+     for unit, count in unit2count.items():
+         if count >= threshold:
+             cans.add(unit)
+     return cans, len(unit2count) - len(cans)
+ 
+ 
+ def filter_inters(inters, can_items=None,
+                   user_k_core_threshold=0, item_k_core_threshold=0):
+     new_inters = []
+ 
+     # filter by meta items
+     if can_items:
+         print('\nFiltering by meta items: ')
+         for unit in inters:
+             if unit[1] in can_items.keys():
+                 new_inters.append(unit)
+         inters, new_inters = new_inters, []
+         print('    The number of inters: ', len(inters))
+ 
+     # filter by k-core
+     if user_k_core_threshold or item_k_core_threshold:
+         print('\nFiltering by k-core:')
+         idx = 0
+         user2count = get_user2count(inters)
+         item2count = get_item2count(inters)
+ 
+         while True:
+             new_user2count = collections.defaultdict(int)
+             new_item2count = collections.defaultdict(int)
+             users, n_filtered_users = generate_candidates(  # users is set
+                 user2count, user_k_core_threshold)
+             items, n_filtered_items = generate_candidates(
+                 item2count, item_k_core_threshold)
+             if n_filtered_users == 0 and n_filtered_items == 0:
+                 break
+             for unit in inters:
+                 if unit[0] in users and unit[1] in items:
+                     new_inters.append(unit)
+                     new_user2count[unit[0]] += 1
+                     new_item2count[unit[1]] += 1
+             idx += 1
+             inters, new_inters = new_inters, []
+             user2count, item2count = new_user2count, new_item2count
+             print('    Epoch %d The number of inters: %d, users: %d, items: %d'
+                   % (idx, len(inters), len(user2count), len(item2count)))
+     return inters
+ 
+ 
+ def make_inters_in_order(inters):
+     user2inters, new_inters = collections.defaultdict(list), list()
+     for inter in inters:
+         user, item, rating, timestamp = inter
+         user2inters[user].append((user, item, rating, timestamp))
+     for user in user2inters:
+         user_inters = user2inters[user]
+         user_inters.sort(key=lambda d: d[3])
+         interacted_item = set()
+         for inter in user_inters:
+             if inter[1] in interacted_item:  # filter duplicate interactions
+                 continue
+             interacted_item.add(inter[1])
+             new_inters.append(inter)
+     return new_inters
+ 
+ 
+ def preprocess_rating(args):
+     dataset_full_name = amazon18_dataset2fullname[args.dataset]
+ 
+     print('Process rating data: ')
+     print(' Dataset: ', args.dataset)
+ 
+     # load ratings
+     rating_file_path = os.path.join(args.input_path, 'Ratings', dataset_full_name + '.csv')
+     rating_users, rating_items, rating_inters = load_ratings(rating_file_path)
+ 
+     # load item IDs with meta data
+     meta_file_path = os.path.join(args.input_path, 'Metadata', f'meta_{dataset_full_name}.json.gz')
+     meta_items = load_meta_items(meta_file_path)
+ 
+     # 1. Filter items w/o meta data;
+     # 2. K-core filtering;
+     print('The number of raw inters: ', len(rating_inters))
+ 
+     rating_inters = make_inters_in_order(rating_inters)
+ 
+     rating_inters = filter_inters(rating_inters, can_items=meta_items,
+                                   user_k_core_threshold=args.user_k,
+                                   item_k_core_threshold=args.item_k)
+ 
+     # sort interactions chronologically for each user
+     rating_inters = make_inters_in_order(rating_inters)
+     print('\n')
+ 
+     # return: list of (user_ID, item_ID, rating, timestamp)
+     return rating_inters, meta_items
+ 
+ def save_inter(args, inters):
+     print('Convert dataset: ')
+     print(' Dataset: ', args.dataset)
+ 
+     with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.inter'), 'w') as file:
+         file.write('user_id:token\titem_id:token\trating:float\ttimestamp:float\n')
+         for inter in inters:
+             user, item, rating, timestamp = inter
+             file.write(f'{user}\t{item}\t{rating}\t{timestamp}\n')
+ 
+ 
+ def save_feat(args, feat, all_items):
+     iid_list = list(feat.keys())
+     num_item = 0
+     with open(os.path.join(args.output_path, args.dataset, f'{args.dataset}.item'), 'w') as file:
+         # "title": title, "description": descriptions, "brand": brand, "categories": categories
+         file.write('item_id:token\ttitle:token_seq\tbrand:token\tcategories:token_seq\n')
+         for iid in iid_list:
+             if iid in all_items:
+                 num_item += 1
+                 title, brand, categories = feat[iid]["title"], feat[iid]["brand"], feat[iid]["categories"]
+                 file.write(f'{iid}\t{title}\t{brand}\t{categories}\n')
+     print("num_item: ", num_item)
+ 
+ 
+ def parse_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
+     parser.add_argument('--user_k', type=int, default=5, help='user k-core filtering')
201
+ parser.add_argument('--item_k', type=int, default=5, help='item k-core filtering')
202
+ parser.add_argument('--input_path', type=str, default='')
203
+ parser.add_argument('--output_path', type=str, default='')
204
+ return parser.parse_args()
205
+
206
+
207
+ if __name__ == '__main__':
208
+ args = parse_args()
209
+
210
+ # load interactions from raw rating file
211
+ rating_inters, meta_items = preprocess_rating(args)
212
+
213
+ check_path(os.path.join(args.output_path, args.dataset))
214
+
215
+
216
+ all_items = set()
217
+ for inter in rating_inters:
218
+ user, item, rating, timestamp = inter
219
+ all_items.add(item)
220
+
221
+ print("total item: ", len(list(all_items)))
222
+
223
+ save_inter(args,rating_inters)
224
+ save_feat(args,meta_items, all_items)
225
+
226
+
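The filtering loop above alternately prunes users and items whose interaction counts fall below a threshold until the set is stable (k-core filtering). A minimal self-contained sketch of the same idea, with illustrative names and toy data rather than the script's helpers:

```python
import collections

def k_core_filter(inters, k=2):
    """Iteratively drop (user, item) pairs whose user or item has fewer
    than k interactions, until no more pairs are removed. A simplified
    sketch of the filter_inters loop; `k_core_filter` is not part of
    the repository's code."""
    while True:
        user2count = collections.Counter(u for u, _ in inters)
        item2count = collections.Counter(i for _, i in inters)
        kept = [(u, i) for u, i in inters
                if user2count[u] >= k and item2count[i] >= k]
        if len(kept) == len(inters):  # converged: nothing was pruned
            return kept
        inters = kept

inters = [('u1', 'a'), ('u1', 'b'), ('u2', 'a'), ('u2', 'b'), ('u3', 'c')]
print(k_core_filter(inters, k=2))
# → [('u1', 'a'), ('u1', 'b'), ('u2', 'a'), ('u2', 'b')]
```

Note that the loop must re-count after each pruning pass: removing `('u3', 'c')` can push other users or items below the threshold, which is why a single pass is not enough.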
data_process/amazon_text_emb.py ADDED
@@ -0,0 +1,139 @@
+import argparse
+import collections
+import gzip
+import html
+import json
+import os
+import random
+import re
+import torch
+from tqdm import tqdm
+import numpy as np
+from utils import *
+from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel
+
+
+def load_data(args):
+    item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json')
+    item2feature = load_json(item2feature_path)
+    return item2feature
+
+
+def generate_text(item2feature, features):
+    item_text_list = []
+
+    for item in item2feature:
+        data = item2feature[item]
+        text = []
+        for meta_key in features:
+            if meta_key in data:
+                meta_value = clean_text(data[meta_key])
+                text.append(meta_value.strip())
+
+        item_text_list.append([int(item), text])
+
+    return item_text_list
+
+
+def preprocess_text(args):
+    print('Process text data: ')
+    print(' Dataset: ', args.dataset)
+
+    item2feature = load_data(args)
+    # load item text and clean
+    item_text_list = generate_text(item2feature, ['title', 'description'])
+    # item_text_list = generate_text(item2feature, ['title'])
+    # return: list of (item_ID, cleaned_item_text)
+    return item_text_list
+
+
+def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1):
+    print('Generate Text Embedding: ')
+    print(' Dataset: ', args.dataset)
+
+    items, texts = zip(*item_text_list)
+    order_texts = [[0]] * len(items)
+    for item, text in zip(items, texts):
+        order_texts[item] = text
+    for text in order_texts:
+        assert text != [0]
+
+    embeddings = []
+    start, batch_size = 0, 1
+    with torch.no_grad():
+        while start < len(order_texts):
+            if (start + 1) % 100 == 0:
+                print("==>", start + 1)
+            field_texts = order_texts[start: start + batch_size]
+            field_texts = zip(*field_texts)
+
+            field_embeddings = []
+            for sentences in field_texts:
+                sentences = list(sentences)
+                if word_drop_ratio > 0:
+                    print(f'Word drop with p={word_drop_ratio}')
+                    new_sentences = []
+                    for sent in sentences:
+                        new_sent = []
+                        sent = sent.split(' ')
+                        for wd in sent:
+                            rd = random.random()
+                            if rd > word_drop_ratio:
+                                new_sent.append(wd)
+                        new_sent = ' '.join(new_sent)
+                        new_sentences.append(new_sent)
+                    sentences = new_sentences
+                encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len,
+                                              truncation=True, return_tensors='pt', padding="longest").to(args.device)
+                outputs = model(input_ids=encoded_sentences.input_ids,
+                                attention_mask=encoded_sentences.attention_mask)
+
+                # mean-pool the last hidden states over non-padding tokens
+                masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1)
+                mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True)
+                mean_output = mean_output.detach().cpu()
+                field_embeddings.append(mean_output)
+
+            field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0)
+            embeddings.append(field_mean_embedding)
+            start += batch_size
+
+    embeddings = torch.cat(embeddings, dim=0).numpy()
+    print('Embeddings shape: ', embeddings.shape)
+
+    file = os.path.join(args.root, args.dataset + '.emb-' + args.plm_name + "-td" + ".npy")
+    np.save(file, embeddings)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
+    parser.add_argument('--root', type=str, default="")
+    parser.add_argument('--gpu_id', type=int, default=2, help='ID of running GPU')
+    parser.add_argument('--plm_name', type=str, default='llama')
+    parser.add_argument('--plm_checkpoint', type=str, default='')
+    parser.add_argument('--max_sent_len', type=int, default=2048)
+    parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    args.root = os.path.join(args.root, args.dataset)
+
+    device = set_device(args.gpu_id)
+    args.device = device
+
+    item_text_list = preprocess_text(args)
+
+    plm_tokenizer, plm_model = load_plm(args.plm_checkpoint)
+    if plm_tokenizer.pad_token_id is None:
+        plm_tokenizer.pad_token_id = 0
+    plm_model = plm_model.to(device)
+
+    generate_item_embedding(args, item_text_list, plm_tokenizer,
+                            plm_model, word_drop_ratio=args.word_drop_ratio)
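The core of `generate_item_embedding` is masked mean pooling: token embeddings are zeroed at padding positions and averaged over the real tokens only. A NumPy sketch of that computation (illustrative function name, toy tensors in place of model outputs):

```python
import numpy as np

def masked_mean_pool(last_hidden_state, attention_mask):
    """Mean-pool token embeddings over non-padding positions, mirroring
    the masked_output / mean_output step above. Shapes:
    last_hidden_state (batch, seq, dim), attention_mask (batch, seq)."""
    mask = attention_mask[..., None]                 # (batch, seq, 1)
    summed = (last_hidden_state * mask).sum(axis=1)  # zero out padding, sum tokens
    counts = attention_mask.sum(axis=1, keepdims=True)
    return summed / counts                           # divide by true token count

hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # last position is padding
mask = np.array([[1, 1, 0]])
print(masked_mean_pool(hidden, mask))  # → [[2. 3.]]
```

Dividing by the mask sum rather than the sequence length is what keeps padded batches from dragging the average toward zero.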
data_process/amazon_user_emb.py ADDED
@@ -0,0 +1,145 @@
+import argparse
+import collections
+import gzip
+import html
+import json
+import os
+import random
+import re
+import torch
+from tqdm import tqdm
+import numpy as np
+from utils import *
+from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig, AutoTokenizer, AutoModel
+
+
+def load_data(args):
+    item2feature_path = os.path.join(args.root, f'{args.dataset}.user.json')
+    item2feature = load_json(item2feature_path)
+    return item2feature
+
+
+def generate_text(item2feature):
+    item_text_list = []
+
+    for item in item2feature:
+        data = item2feature[item]
+        text = []
+        for i in range(len(data)):
+            meta_value = clean_text(data[i])
+            text.append(meta_value.strip())
+
+        item_text_list.append([int(item), text])
+
+    return item_text_list
+
+
+def preprocess_text(args):
+    print('Process text data ......')
+    print('Dataset:', args.dataset)
+
+    item2feature = load_data(args)
+    item2feature = item2feature['user_explicit_preference']
+    # load user preference text and clean
+    item_text_list = generate_text(item2feature)
+    # return: list of (user_ID, cleaned_preference_texts)
+    return item_text_list
+
+
+def generate_item_embedding(args, item_text_list, tokenizer, model, word_drop_ratio=-1):
+    print('Generate Text Embedding ......')
+    print('Dataset:', args.dataset)
+
+    items, texts = zip(*item_text_list)
+    order_texts = [[0]] * len(items)
+    for item, text in zip(items, texts):
+        order_texts[item] = text
+    for text in order_texts:
+        assert text != [0]
+
+    embeddings = []
+    start, batch_size = 0, 1
+    with torch.no_grad():
+        while start < len(order_texts):
+            if (start + 1) % 100 == 0:
+                print("==>", start + 1)
+            field_texts = order_texts[start: start + batch_size]
+            field_texts = zip(*field_texts)
+
+            field_embeddings = []
+            for sentences in field_texts:
+                sentences = list(sentences)
+                if word_drop_ratio > 0:
+                    print(f'Word drop with p={word_drop_ratio}')
+                    new_sentences = []
+                    for sent in sentences:
+                        new_sent = []
+                        sent = sent.split(' ')
+                        for wd in sent:
+                            rd = random.random()
+                            if rd > word_drop_ratio:
+                                new_sent.append(wd)
+                        new_sent = ' '.join(new_sent)
+                        new_sentences.append(new_sent)
+                    sentences = new_sentences
+                encoded_sentences = tokenizer(sentences, max_length=args.max_sent_len,
+                                              truncation=True, return_tensors='pt', padding="longest").to(args.device)
+                outputs = model(input_ids=encoded_sentences.input_ids,
+                                attention_mask=encoded_sentences.attention_mask)
+
+                # mean-pool the last hidden states over non-padding tokens
+                masked_output = outputs.last_hidden_state * encoded_sentences['attention_mask'].unsqueeze(-1)
+                mean_output = masked_output.sum(dim=1) / encoded_sentences['attention_mask'].sum(dim=-1, keepdim=True)
+                mean_output = mean_output.detach().cpu()
+                field_embeddings.append(mean_output)
+
+            field_mean_embedding = torch.stack(field_embeddings, dim=0).mean(dim=0)
+            embeddings.append(field_mean_embedding)
+            start += batch_size
+
+    embeddings = torch.cat(embeddings, dim=0).numpy()
+    print('Embeddings shape: ', embeddings.shape)
+
+    np.save(args.save_path, embeddings)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, default='Arts', help='Instruments / Arts / Games')
+    parser.add_argument('--root', type=str, default="")
+    parser.add_argument('--gpu_id', type=int, default=2, help='ID of running GPU')
+    parser.add_argument('--plm_name', type=str, default='llama')
+    parser.add_argument('--plm_checkpoint', type=str, default='')
+    parser.add_argument('--max_sent_len', type=int, default=2048)
+    parser.add_argument('--word_drop_ratio', type=float, default=-1, help='word drop ratio, do not drop by default')
+    parser.add_argument('--save_path', type=str, default="")
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    args.root = os.path.join(args.root, args.dataset)
+
+    device = set_device(args.gpu_id)
+    args.device = device
+
+    item_text_list = preprocess_text(args)
+
+    plm_tokenizer, plm_model = load_plm(args.plm_checkpoint)
+    if plm_tokenizer.pad_token_id is None:
+        plm_tokenizer.pad_token_id = 0
+    plm_model = plm_model.to(device)
+
+    generate_item_embedding(args, item_text_list, plm_tokenizer,
+                            plm_model, word_drop_ratio=args.word_drop_ratio)
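Both embedding scripts share the optional word-dropout branch: each whitespace-separated word is independently kept with probability `1 - word_drop_ratio` before tokenization, as a simple text augmentation. A standalone sketch of that branch (function name and injected RNG are illustrative, added for testability):

```python
import random

def word_drop(sentence, ratio, rng=random):
    """Drop each word independently with probability `ratio`, mirroring
    the word_drop_ratio branch in generate_item_embedding. `rng` is any
    object exposing random() in [0, 1)."""
    kept = [w for w in sentence.split(' ') if rng.random() > ratio]
    return ' '.join(kept)

rng = random.Random(0)  # seeded for reproducibility
print(word_drop('a guitar with nylon strings', 0.5, rng))
```

With `ratio <= 0` every word survives (since `random()` is almost surely positive), which matches the scripts' default of `-1` meaning "no dropout".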
data_process/get_llm_output.py ADDED
@@ -0,0 +1,374 @@
+import argparse
+import os
+import os.path as osp
+import random
+import time
+from logging import getLogger
+import openai
+from utils import get_res_batch, load_json, intention_prompt, preference_prompt_1, preference_prompt_2, amazon18_dataset2fullname, write_json_file
+import json
+
+
+def get_intention_train(args, inters, item2feature, reviews, api_info):
+    intention_train_output_file = os.path.join(args.root, "intention_train.json")
+
+    # Suggest modifying the prompt based on different datasets
+    prompt = intention_prompt
+    dataset_full_name = amazon18_dataset2fullname[args.dataset]
+    dataset_full_name = dataset_full_name.replace("_", " ").lower()
+    print(dataset_full_name)
+
+    prompt_list = []
+    inter_data = []
+
+    for (user, item_list) in inters.items():
+        user = int(user)
+        item = int(item_list[-3])
+        history = item_list[:-3]
+
+        inter_data.append((user, item, history))
+
+        review = reviews[str((user, item))]["review"]
+        item_title = item2feature[str(item)]["title"]
+        input_prompt = prompt.format(item_title=item_title, dataset_full_name=dataset_full_name, review=review)
+        prompt_list.append(input_prompt)
+
+    st = 0
+    with open(intention_train_output_file, mode='a') as f:
+        while st < len(prompt_list):
+            print(st)
+            # resume after an interruption, e.g.:
+            # if st < 25631:
+            #     st += args.batchsize
+            #     continue
+
+            res = get_res_batch(args.model_name, prompt_list[st:st + args.batchsize], args.max_tokens, api_info)
+
+            for i, answer in enumerate(res):
+                user, item, history = inter_data[st + i]
+
+                if answer == '':
+                    print("answer null error")
+                    answer = "I enjoy high-quality item."
+
+                if answer.strip().count('\n') != 1:
+                    if 'haracteristics:' in answer:
+                        answer = answer.strip().split("The item's characteristics:")
+                    else:
+                        answer = answer.strip().split("The item's characteristic:")
+                else:
+                    answer = answer.strip().split('\n')
+
+                if '' in answer:
+                    answer.remove('')
+
+                if len(answer) == 1:
+                    print(answer)
+                    user_preference = item_character = answer[0]
+                elif len(answer) >= 3:
+                    print(answer)
+                    answer = answer[-1]
+                    user_preference = item_character = answer
+                else:
+                    user_preference, item_character = answer
+
+                if ':' in user_preference:
+                    idx = user_preference.index(':')
+                    user_preference = user_preference[idx + 1:]
+                    user_preference = user_preference.strip().replace('}', '')
+                    user_preference = user_preference.replace('\n', '')
+
+                if ':' in item_character:
+                    idx = item_character.index(':')
+                    item_character = item_character[idx + 1:]
+                    item_character = item_character.strip().replace('}', '')
+                    item_character = item_character.replace('\n', '')
+
+                record = {"user": user, "item": item, "inters": history,
+                          "user_related_intention": user_preference, "item_related_intention": item_character}
+
+                json.dump(record, f)
+                f.write("\n")
+
+            st += args.batchsize
+
+    return intention_train_output_file
+
+
+def get_intention_test(args, inters, item2feature, reviews, api_info):
+    intention_test_output_file = os.path.join(args.root, "intention_test.json")
+
+    # Suggest modifying the prompt based on different datasets
+    prompt = intention_prompt
+    dataset_full_name = amazon18_dataset2fullname[args.dataset]
+    dataset_full_name = dataset_full_name.replace("_", " ").lower()
+    print(dataset_full_name)
+
+    prompt_list = []
+    inter_data = []
+
+    for (user, item_list) in inters.items():
+        user = int(user)
+        item = int(item_list[-1])
+        history = item_list[:-1]
+
+        inter_data.append((user, item, history))
+
+        review = reviews[str((user, item))]["review"]
+        item_title = item2feature[str(item)]["title"]
+        input_prompt = prompt.format(item_title=item_title, dataset_full_name=dataset_full_name, review=review)
+        prompt_list.append(input_prompt)
+
+    st = 0
+    with open(intention_test_output_file, mode='a') as f:
+        while st < len(prompt_list):
+            print(st)
+            # resume after an interruption, e.g.:
+            # if st < 4623:
+            #     st += args.batchsize
+            #     continue
+
+            res = get_res_batch(args.model_name, prompt_list[st:st + args.batchsize], args.max_tokens, api_info)
+
+            for i, answer in enumerate(res):
+                user, item, history = inter_data[st + i]
+
+                if answer == '':
+                    print("answer null error")
+                    answer = "I enjoy high-quality item."
+
+                if answer.strip().count('\n') != 1:
+                    if 'haracteristics:' in answer:
+                        answer = answer.strip().split("The item's characteristics:")
+                    else:
+                        answer = answer.strip().split("The item's characteristic:")
+                else:
+                    answer = answer.strip().split('\n')
+
+                if '' in answer:
+                    answer.remove('')
+
+                if len(answer) == 1:
+                    print(answer)
+                    user_preference = item_character = answer[0]
+                elif len(answer) >= 3:
+                    print(answer)
+                    answer = answer[-1]
+                    user_preference = item_character = answer
+                else:
+                    user_preference, item_character = answer
+
+                if ':' in user_preference:
+                    idx = user_preference.index(':')
+                    user_preference = user_preference[idx + 1:]
+                    user_preference = user_preference.strip().replace('}', '')
+                    user_preference = user_preference.replace('\n', '')
+
+                if ':' in item_character:
+                    idx = item_character.index(':')
+                    item_character = item_character[idx + 1:]
+                    item_character = item_character.strip().replace('}', '')
+                    item_character = item_character.replace('\n', '')
+
+                record = {"user": user, "item": item, "inters": history,
+                          "user_related_intention": user_preference, "item_related_intention": item_character}
+
+                json.dump(record, f)
+                f.write("\n")
+
+            st += args.batchsize
+
+    return intention_test_output_file
+
+
+def get_user_preference(args, inters, item2feature, reviews, api_info):
+    preference_output_file = os.path.join(args.root, "user_preference.json")
+
+    # Suggest modifying the prompt based on different datasets
+    prompt_1 = preference_prompt_1
+    prompt_2 = preference_prompt_2
+
+    dataset_full_name = amazon18_dataset2fullname[args.dataset]
+    dataset_full_name = dataset_full_name.replace("_", " ").lower()
+    print(dataset_full_name)
+
+    prompt_list_1 = []
+    prompt_list_2 = []
+    users = []
+
+    for (user, item_list) in inters.items():
+        users.append(user)
+        history = item_list[:-3]
+        item_titles = []
+        for j, item in enumerate(history):
+            item_titles.append(str(j + 1) + '.' + item2feature[str(item)]["title"])
+        if len(item_titles) > args.max_his_len:
+            item_titles = item_titles[-args.max_his_len:]
+        item_titles = ", ".join(item_titles)
+
+        input_prompt_1 = prompt_1.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
+        input_prompt_2 = prompt_2.format(dataset_full_name=dataset_full_name, item_titles=item_titles)
+
+        prompt_list_1.append(input_prompt_1)
+        prompt_list_2.append(input_prompt_2)
+
+    st = 0
+    with open(preference_output_file, mode='a') as f:
+        while st < len(prompt_list_1):
+            print(st)
+            # resume after an interruption, e.g.:
+            # if st < 22895:
+            #     st += args.batchsize
+            #     continue
+
+            res_1 = get_res_batch(args.model_name, prompt_list_1[st:st + args.batchsize], args.max_tokens, api_info)
+            res_2 = get_res_batch(args.model_name, prompt_list_2[st:st + args.batchsize], args.max_tokens, api_info)
+            for i, answers in enumerate(zip(res_1, res_2)):
+                user = users[st + i]
+                answer_1, answer_2 = answers
+
+                if answer_1 == '':
+                    print("answer null error")
+                    answer_1 = "I enjoy high-quality item."
+
+                if answer_2 == '':
+                    print("answer null error")
+                    answer_2 = "I enjoy high-quality item."
+
+                if answer_2.strip().count('\n') != 1:
+                    if 'references:' in answer_2:
+                        answer_2 = answer_2.strip().split("Short-term preferences:")
+                    else:
+                        answer_2 = answer_2.strip().split("Short-term preference:")
+                else:
+                    answer_2 = answer_2.strip().split('\n')
+
+                if '' in answer_2:
+                    answer_2.remove('')
+
+                if len(answer_2) == 1:
+                    print(answer_2)
+                    long_preference = short_preference = answer_2[0]
+                elif len(answer_2) >= 3:
+                    print(answer_2)
+                    answer_2 = answer_2[-1]
+                    long_preference = short_preference = answer_2
+                else:
+                    long_preference, short_preference = answer_2
+
+                if ':' in long_preference:
+                    idx = long_preference.index(':')
+                    long_preference = long_preference[idx + 1:]
+                    long_preference = long_preference.strip().replace('}', '')
+                    long_preference = long_preference.replace('\n', '')
+
+                if ':' in short_preference:
+                    idx = short_preference.index(':')
+                    short_preference = short_preference[idx + 1:]
+                    short_preference = short_preference.strip().replace('}', '')
+                    short_preference = short_preference.replace('\n', '')
+
+                record = {"user": user, "user_preference": [answer_1, long_preference, short_preference]}
+                json.dump(record, f)
+                f.write("\n")
+
+            st += args.batchsize
+
+    return preference_output_file
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--dataset', type=str, default='Instruments', help='Instruments / Arts / Games')
+    parser.add_argument('--root', type=str, default='')
+    parser.add_argument('--api_info', type=str, default='./api_info.json')
+    parser.add_argument('--model_name', type=str, default='text-davinci-003')
+    parser.add_argument('--max_tokens', type=int, default=512)
+    parser.add_argument('--batchsize', type=int, default=16)
+    parser.add_argument('--max_his_len', type=int, default=20)
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
+    args = parse_args()
+
+    args.root = os.path.join(args.root, args.dataset)
+
+    api_info = load_json(args.api_info)
+    openai.api_key = api_info["api_key_list"].pop()
+
+    inter_path = os.path.join(args.root, f'{args.dataset}.inter.json')
+    inters = load_json(inter_path)
+
+    item2feature_path = os.path.join(args.root, f'{args.dataset}.item.json')
+    item2feature = load_json(item2feature_path)
+
+    reviews_path = os.path.join(args.root, f'{args.dataset}.review.json')
+    reviews = load_json(reviews_path)
+
+    intention_train_output_file = get_intention_train(args, inters, item2feature, reviews, api_info)
+    intention_test_output_file = get_intention_test(args, inters, item2feature, reviews, api_info)
+    preference_output_file = get_user_preference(args, inters, item2feature, reviews, api_info)
+
+    intention_train = {}
+    intention_test = {}
+    user_preference = {}
+
+    with open(intention_train_output_file, "r") as f:
+        for line in f:
+            content = json.loads(line)
+            if content["user"] not in intention_train:
+                intention_train[content["user"]] = {"item": content["item"],
+                                                    "inters": content["inters"],
+                                                    "querys": [content["user_related_intention"], content["item_related_intention"]]}
+
+    with open(intention_test_output_file, "r") as f:
+        for line in f:
+            content = json.loads(line)
+            if content["user"] not in intention_test:
+                intention_test[content["user"]] = {"item": content["item"],
+                                                   "inters": content["inters"],
+                                                   "querys": [content["user_related_intention"], content["item_related_intention"]]}
+
+    with open(preference_output_file, "r") as f:
+        for line in f:
+            content = json.loads(line)
+            user_preference[content["user"]] = content["user_preference"]
+
+    user_dict = {
+        "user_explicit_preference": user_preference,
+        "user_vague_intention": {"train": intention_train, "test": intention_test},
+    }
+
+    write_json_file(user_dict, os.path.join(args.root, f'{args.dataset}.user.json'))
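Each completion is expected in a two-field format such as `My preferences: [...]` on one line and `The item's characteristics: [...]` on the next, and the script's parsing falls back to a single shared value when the model does not comply. A compact sketch of that parsing contract (simplified relative to the full branch logic above; the function name is illustrative):

```python
def split_llm_answer(answer):
    """Split a two-line 'field: value' completion into its two values,
    mirroring (in simplified form) the parsing in get_intention_train.
    If the model returns a single line, both fields share that line."""
    parts = [p for p in answer.strip().split('\n') if p]
    if len(parts) == 1:
        preference = characteristics = parts[0]
    else:
        preference, characteristics = parts[0], parts[-1]
    # strip the leading 'Field name:' label, if present
    def clean(s):
        return s.split(':', 1)[1].strip() if ':' in s else s.strip()
    return clean(preference), clean(characteristics)

print(split_llm_answer(
    "My preferences: [warm tone]\nThe item's characteristics: [nylon strings]"))
# → ('[warm tone]', '[nylon strings]')
```

Keeping the fallback (one value copied into both fields) means a malformed completion degrades gracefully instead of crashing a long, resumable API run.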
data_process/utils.py ADDED
@@ -0,0 +1,238 @@
+import html
+import json
+import os
+import pickle
+import re
+import time
+
+import torch
+# import gensim
+from transformers import AutoModel, AutoTokenizer
+import collections
+import openai
+
+
+def get_res_batch(model_name, prompt_list, max_tokens, api_info):
+    while True:
+        try:
+            res = openai.Completion.create(
+                model=model_name,
+                prompt=prompt_list,
+                temperature=0.4,
+                max_tokens=max_tokens,
+                top_p=1,
+                frequency_penalty=0,
+                presence_penalty=0
+            )
+            output_list = []
+            for choice in res['choices']:
+                output = choice['text'].strip()
+                output_list.append(output)
+
+            return output_list
+
+        except openai.error.AuthenticationError as e:
+            print(e)
+            openai.api_key = api_info["api_key_list"].pop()
+            time.sleep(10)
+        except openai.error.RateLimitError as e:
+            print(e)
+            if str(e) == "You exceeded your current quota, please check your plan and billing details.":
+                openai.api_key = api_info["api_key_list"].pop()
+                time.sleep(10)
+            else:
+                print('\nopenai.error.RateLimitError\nRetrying...')
+                time.sleep(10)
+        except openai.error.ServiceUnavailableError as e:
+            print(e)
+            print('\nopenai.error.ServiceUnavailableError\nRetrying...')
+            time.sleep(10)
+        except openai.error.Timeout:
+            print('\nopenai.error.Timeout\nRetrying...')
+            time.sleep(10)
+        except openai.error.APIError as e:
+            print(e)
+            print('\nopenai.error.APIError\nRetrying...')
+            time.sleep(10)
+        except openai.error.APIConnectionError as e:
+            print(e)
+            print('\nopenai.error.APIConnectionError\nRetrying...')
+            time.sleep(10)
+        except Exception as e:
+            print(e)
+            return None
+
+
+def check_path(path):
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def set_device(gpu_id):
+    if gpu_id == -1:
+        return torch.device('cpu')
+    else:
+        return torch.device(
+            'cuda:' + str(gpu_id) if torch.cuda.is_available() else 'cpu')
+
+
+def load_plm(model_path='bert-base-uncased'):
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    print("Load Model:", model_path)
+    model = AutoModel.from_pretrained(model_path, low_cpu_mem_usage=True)
+    return tokenizer, model
+
+
+def load_json(file):
+    with open(file, 'r') as f:
+        data = json.load(f)
+    return data
+
+
+def clean_text(raw_text):
+    if isinstance(raw_text, list):
+        new_raw_text = []
+        for raw in raw_text:
+            raw = html.unescape(raw)
+            raw = re.sub(r'</?\w+[^>]*>', '', raw)
+            raw = re.sub(r'["\n\r]*', '', raw)
+            new_raw_text.append(raw.strip())
+        cleaned_text = ' '.join(new_raw_text)
+    else:
+        if isinstance(raw_text, dict):
+            cleaned_text = str(raw_text)[1:-1].strip()
+        else:
+            cleaned_text = raw_text.strip()
+        cleaned_text = html.unescape(cleaned_text)
+        cleaned_text = re.sub(r'</?\w+[^>]*>', '', cleaned_text)
+        cleaned_text = re.sub(r'["\n\r]*', '', cleaned_text)
+    # collapse any run of trailing '.' to a single period (append one if absent)
+    index = -1
+    while -index < len(cleaned_text) and cleaned_text[index] == '.':
+        index -= 1
+    index += 1
+    if index == 0:
+        cleaned_text = cleaned_text + '.'
+    else:
+        cleaned_text = cleaned_text[:index] + '.'
+    if len(cleaned_text) >= 2000:
+        cleaned_text = ''
+    return cleaned_text
+
+
+def load_pickle(filename):
+    with open(filename, "rb") as f:
+        return pickle.load(f)
+
+
+def make_inters_in_order(inters):
+    user2inters, new_inters = collections.defaultdict(list), list()
+    for inter in inters:
+        user, item, rating, timestamp = inter
+        user2inters[user].append((user, item, rating, timestamp))
+    for user in user2inters:
+        user_inters = user2inters[user]
+        user_inters.sort(key=lambda d: d[3])
+        for inter in user_inters:
+            new_inters.append(inter)
+    return new_inters
+
+
+def write_json_file(dic, file):
+    print('Writing json file: ', file)
+    with open(file, 'w') as fp:
+        json.dump(dic, fp, indent=4)
+
+
+def write_remap_index(unit2index, file):
+    print('Writing remap file: ', file)
+    with open(file, 'w') as fp:
+        for unit in unit2index:
+            fp.write(unit + '\t' + str(unit2index[unit]) + '\n')
+
+
+intention_prompt = "After purchasing a {dataset_full_name} item named \"{item_title}\", the user left a comment expressing his opinion and personal preferences. The user's comment is as follows: \n\"{review}\" " \
+    "\nAs we all know, user comments often contain information about both their personal preferences and the characteristics of the item they interacted with. From this comment, you can infer both the user's personal preferences and the characteristics of the item. " \
+    "Please describe your inferred user preferences and item characteristics in the first person and in the following format:\n\nMy preferences: []\nThe item's characteristics: []\n\n" \
+    "Note that your inference of the personalized preferences should not include any information about the title of the item."
+
+
+preference_prompt_1 = "Suppose the user has bought a variety of {dataset_full_name} items, they are: \n{item_titles}. \nAs we all know, these historically purchased items serve as a reflection of the user's personalized preferences. " \
+    "Please analyze the user's personalized preferences based on the items he has bought and provide a brief third-person summary of the user's preferences, highlighting the key factors that influence his choice of items. Avoid listing specific items and do not list multiple examples. " \
+    "Your analysis should be brief and in the third person."
+
+preference_prompt_2 = "Given a chronological list of {dataset_full_name} items that a user has purchased, we can analyze his long-term and short-term preferences. Long-term preferences are inherent characteristics of the user, which are reflected in all the items he has interacted with over time. Short-term preferences are the user's recent preferences, which are reflected in some of the items he has bought more recently. " \
+    "To determine the user's long-term preferences, please analyze the contents of all the items he has bought. Look for common features that appear frequently across the user's shopping records. To determine the user's short-term preferences, focus on the items he has bought most recently. Identify any new or different features that have emerged in the user's shopping records. " \
+    "Here is a chronological list of items that the user has bought: \n{item_titles}. \nPlease provide separate analyses for the user's long-term and short-term preferences. Your answer should be concise and general, without listing specific items. Your answer should be in the third person and in the following format:\n\nLong-term preferences: []\nShort-term preferences: []\n\n"
+
+
+# remove 'Magazine', 'Gift', 'Music', 'Kindle'
170
+ amazon18_dataset_list = [
171
+ 'Appliances', 'Beauty',
172
+ 'Fashion', 'Software', 'Luxury', 'Scientific', 'Pantry',
173
+ 'Instruments', 'Arts', 'Games', 'Office', 'Garden',
174
+ 'Food', 'Cell', 'CDs', 'Automotive', 'Toys',
175
+ 'Pet', 'Tools', 'Kindle', 'Sports', 'Movies',
176
+ 'Electronics', 'Home', 'Clothing', 'Books'
177
+ ]
178
+
179
+ amazon18_dataset2fullname = {
180
+ 'Beauty': 'All_Beauty',
181
+ 'Fashion': 'AMAZON_FASHION',
182
+ 'Appliances': 'Appliances',
183
+ 'Arts': 'Arts_Crafts_and_Sewing',
184
+ 'Automotive': 'Automotive',
185
+ 'Books': 'Books',
186
+ 'CDs': 'CDs_and_Vinyl',
187
+ 'Cell': 'Cell_Phones_and_Accessories',
188
+ 'Clothing': 'Clothing_Shoes_and_Jewelry',
189
+ 'Music': 'Digital_Music',
190
+ 'Electronics': 'Electronics',
191
+ 'Gift': 'Gift_Cards',
192
+ 'Food': 'Grocery_and_Gourmet_Food',
193
+ 'Home': 'Home_and_Kitchen',
194
+ 'Scientific': 'Industrial_and_Scientific',
195
+ 'Kindle': 'Kindle_Store',
196
+ 'Luxury': 'Luxury_Beauty',
197
+ 'Magazine': 'Magazine_Subscriptions',
198
+ 'Movies': 'Movies_and_TV',
199
+ 'Instruments': 'Musical_Instruments',
200
+ 'Office': 'Office_Products',
201
+ 'Garden': 'Patio_Lawn_and_Garden',
202
+ 'Pet': 'Pet_Supplies',
203
+ 'Pantry': 'Prime_Pantry',
204
+ 'Software': 'Software',
205
+ 'Sports': 'Sports_and_Outdoors',
206
+ 'Tools': 'Tools_and_Home_Improvement',
207
+ 'Toys': 'Toys_and_Games',
208
+ 'Games': 'Video_Games'
209
+ }
210
+
211
+ amazon14_dataset_list = [
212
+ 'Beauty','Toys','Sports'
213
+ ]
214
+
215
+ amazon14_dataset2fullname = {
216
+ 'Beauty': 'Beauty',
217
+ 'Sports': 'Sports_and_Outdoors',
218
+ 'Toys': 'Toys_and_Games',
219
+ }
220
+
221
+ # c1. c2. c3. c4.
222
+ amazon_text_feature1 = ['title', 'category', 'brand']
223
+
224
+ # re-order
225
+ amazon_text_feature1_ro1 = ['brand', 'main_cat', 'category', 'title']
226
+
227
+ # remove
228
+ amazon_text_feature1_re1 = ['title']
229
+
230
+ amazon_text_feature2 = ['title']
231
+
232
+ amazon_text_feature3 = ['description']
233
+
234
+ amazon_text_feature4 = ['description', 'main_cat', 'category', 'brand']
235
+
236
+ amazon_text_feature5 = ['title', 'description']
237
+
238
+
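`make_inters_in_order` groups raw `(user, item, rating, timestamp)` records by user and sorts each user's history by timestamp. A standalone sketch of the same logic on toy records:

```python
import collections

def make_inters_in_order(inters):
    # group (user, item, rating, timestamp) records by user,
    # then sort each user's records chronologically by timestamp
    user2inters, new_inters = collections.defaultdict(list), []
    for user, item, rating, timestamp in inters:
        user2inters[user].append((user, item, rating, timestamp))
    for user, user_inters in user2inters.items():
        user_inters.sort(key=lambda d: d[3])
        new_inters.extend(user_inters)
    return new_inters

# toy interactions in arbitrary order
inters = [
    ("u1", "i2", 5.0, 200),
    ("u2", "i9", 4.0, 50),
    ("u1", "i1", 3.0, 100),
]
ordered = make_inters_in_order(inters)
print(ordered)
# u1's history comes first (insertion order) and is sorted by timestamp
```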
evaluate-finetuned.py ADDED
@@ -0,0 +1,197 @@
+ import argparse
+ import json
+ import os
+ import sys
+
+ import torch
+ import transformers
+ import torch.distributed as dist
+ from torch.utils.data.distributed import DistributedSampler
+ from torch.nn.parallel import DistributedDataParallel
+ from peft import PeftModel
+ from torch.utils.data import DataLoader
+ from tqdm import tqdm
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
+
+ from utils import *
+ from collator import TestCollator
+ from prompt import all_prompt
+ from evaluate import get_topk_results, get_metrics_results
+
+ parser = argparse.ArgumentParser(description='rqllama-evaluate')
+ parser = parse_evaluate_args(parser)
+ args = parser.parse_args()
+
+ set_seed(args.seed)
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ local_rank = int(os.environ.get("LOCAL_RANK") or 0)
+ torch.cuda.set_device(local_rank)
+ if local_rank == 0:
+     print(vars(args))
+
+ dist.init_process_group(backend="nccl", world_size=world_size, rank=local_rank)
+
+ device_map = {"": local_rank}
+ device = torch.device("cuda", local_rank)
+
+ tokenizer = LlamaTokenizer.from_pretrained(args.ckpt_path)
+ base_model = LlamaForCausalLM.from_pretrained(args.base_model, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
+ base_model.resize_token_embeddings(len(tokenizer))
+ model = PeftModel.from_pretrained(base_model, args.ckpt_path, torch_dtype=torch.float16, device_map=device_map)
+
+ model = DistributedDataParallel(model, device_ids=[local_rank])
+
+ if args.test_prompt_ids == "all":
+     if args.test_task.lower() == "seqrec":
+         prompt_ids = range(len(all_prompt["seqrec"]))
+     elif args.test_task.lower() == "itemsearch":
+         prompt_ids = range(len(all_prompt["itemsearch"]))
+     elif args.test_task.lower() == "fusionseqrec":
+         prompt_ids = range(len(all_prompt["fusionseqrec"]))
+ else:
+     prompt_ids = [int(_) for _ in args.test_prompt_ids.split(",")]
+
+ test_data = load_test_dataset(args)
+ if local_rank == 0:
+     print("evaluate data num:", len(test_data))
+ ddp_sampler = DistributedSampler(test_data, num_replicas=world_size, rank=local_rank, drop_last=True)
+ collator = TestCollator(args, tokenizer)
+ all_items = test_data.get_all_items()
+ prefix_allowed_tokens = test_data.get_prefix_allowed_tokens_fn(tokenizer)
+ test_loader = DataLoader(
+     test_data,
+     batch_size=args.test_batch_size,
+     collate_fn=collator,
+     sampler=ddp_sampler,
+     num_workers=4,
+     pin_memory=True
+ )
+
+ model.eval()
+
+ metrics = args.metrics.split(",")
+ all_prompt_results = []
+
+ print('prompts:', len(prompt_ids))
+
+ with torch.no_grad():
+     for prompt_id in prompt_ids:
+         if local_rank == 0:
+             print("Start prompt: ", prompt_id)
+         test_loader.dataset.set_prompt(prompt_id)
+         metrics_results = {}
+         total = 0
+
+         for step, batch in enumerate(tqdm(test_loader)):
+             inputs = batch[0].to(device)
+             targets = batch[1]
+             bs = len(targets)
+             num_beams = args.num_beams
+
+             while True:
+                 try:
+                     output = model.module.generate(
+                         input_ids=inputs["input_ids"],
+                         attention_mask=inputs["attention_mask"],
+                         max_new_tokens=10,
+                         prefix_allowed_tokens_fn=prefix_allowed_tokens,
+                         num_beams=num_beams,
+                         num_return_sequences=num_beams,
+                         output_scores=True,
+                         return_dict_in_generate=True,
+                         early_stopping=True,
+                     )
+                     break
+                 except torch.cuda.OutOfMemoryError:
+                     # back off on the beam size until generation fits in memory;
+                     # give up (re-raise) once the beam size would reach zero
+                     print("Out of memory!")
+                     num_beams -= 1
+                     print("Beam:", num_beams)
+                     if num_beams == 0:
+                         raise
+
+             output_ids = output["sequences"]
+             scores = output["sequences_scores"]
+
+             # output_ids.shape: torch.Size([20, 101])
+             # scores.shape: torch.Size([20])
+
+             output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
+             # output.length: 20
+             '''
+             Below is an instruction that describes a task.
+             Write a response that appropriately completes the request.\n\n
+             ### Instruction:\nThe user has interacted with items <a-213> <b-171> <c-26> <d-74> <p-0> , <a-14> <b-33> <c-196> <d-121> <p-0> ,
+             <a-213> <b-23> <c-128> <d-13> <p-8> , <a-1> <b-23> <c-68> <d-71> <p-1> in chronological order.
+             Can you predict the next possible item that the user may expect?\n\n
+             ### Response: <a-9> <b-23> <c-123> <d-85> <p-2>
+             '''
+
+             topk_res = get_topk_results(
+                 output,
+                 scores,
+                 targets,
+                 num_beams,
+                 all_items=all_items if args.filter_items else None
+             )
+
+             bs_gather_list = [None for _ in range(world_size)]
+             dist.all_gather_object(obj=bs, object_list=bs_gather_list)
+             total += sum(bs_gather_list)
+             res_gather_list = [None for _ in range(world_size)]
+             dist.all_gather_object(obj=topk_res, object_list=res_gather_list)
+
+             if local_rank == 0:
+                 all_device_topk_res = []
+                 for ga_res in res_gather_list:
+                     all_device_topk_res += ga_res
+                 batch_metrics_res = get_metrics_results(all_device_topk_res, metrics)
+                 for m, res in batch_metrics_res.items():
+                     if m not in metrics_results:
+                         metrics_results[m] = res
+                     else:
+                         metrics_results[m] += res
+
+                 if (step + 1) % 50 == 0:
+                     temp = {}
+                     for m in metrics_results:
+                         temp[m] = metrics_results[m] / total
+                     print(temp)
+             dist.barrier()
+
+         if local_rank == 0:
+             for m in metrics_results:
+                 metrics_results[m] = metrics_results[m] / total
+             all_prompt_results.append(metrics_results)
+             print("======================================================")
+             print("Prompt {} results: ".format(prompt_id), metrics_results)
+             print("======================================================")
+             print("")
+         dist.barrier()
+ dist.barrier()
+
+ if local_rank == 0:
+     mean_results = {}
+     min_results = {}
+     max_results = {}
+
+     for m in metrics:
+         all_res = [_[m] for _ in all_prompt_results]
+         mean_results[m] = sum(all_res) / len(all_res)
+         min_results[m] = min(all_res)
+         max_results[m] = max(all_res)
+
+     print("======================================================")
+     print("Mean results: ", mean_results)
+     print("Min results: ", min_results)
+     print("Max results: ", max_results)
+     print("======================================================")
+
+     save_data = {}
+     save_data["test_prompt_ids"] = args.test_prompt_ids
+     save_data["mean_results"] = mean_results
+     save_data["min_results"] = min_results
+     save_data["max_results"] = max_results
+     save_data["all_prompt_results"] = all_prompt_results
+
+     with open(args.results_file, "w") as f:
+         json.dump(save_data, f, indent=4)
+     print("Save file: ", args.results_file)
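`evaluate-finetuned.py` constrains beam search with a `prefix_allowed_tokens_fn` so that generated sequences can only spell out valid item indices. A minimal sketch of how such a constraint can be built from a prefix tree; the helper names and token ids are illustrative, not the repo's actual implementation, and plain lists stand in for the tensors Hugging Face actually passes:

```python
def build_prefix_tree(sequences):
    # map each generated prefix (as a tuple) to the set of token ids allowed next
    tree = {}
    for seq in sequences:
        for i in range(len(seq)):
            tree.setdefault(tuple(seq[:i]), set()).add(seq[i])
    return tree

def make_prefix_allowed_tokens_fn(tree, prompt_len):
    # only the generated suffix (tokens after the prompt) is matched against the tree
    def prefix_allowed_tokens(batch_id, input_ids):
        generated = tuple(input_ids[prompt_len:])
        return sorted(tree.get(generated, set()))
    return prefix_allowed_tokens

# toy vocabulary: two valid item index sequences
valid_items = [[5, 7, 2], [5, 9, 3]]
fn = make_prefix_allowed_tokens_fn(build_prefix_tree(valid_items), prompt_len=4)

print(fn(0, [1, 1, 1, 1]))     # -> [5]     (every valid index starts with 5)
print(fn(0, [1, 1, 1, 1, 5]))  # -> [7, 9]  (two branches after 5)
```

Restricting the allowed next tokens at every decoding step guarantees each beam decodes to a well-formed item index, which is why the script can map beams back to items directly.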
evaluate.py ADDED
@@ -0,0 +1,69 @@
+ import math
+
+ def get_topk_results(predictions, scores, targets, k, all_items=None):
+     # target: ['<a-5><b-248><c-226><d-145>']
+     results = []
+     B = len(targets)
+     predictions = [_.split("Response:")[-1] for _ in predictions]
+     predictions = [_.strip().replace(" ", "") for _ in predictions]
+     # prediction: ['<a-9><b-70><c-10><d-21>', '<a-5><b-88><c-103><d-74>', '<a-29><b-70><c-36><d-113>']
+
+     if all_items is not None:
+         # push predictions that are not valid item indices to the bottom of the ranking
+         for i, seq in enumerate(predictions):
+             if seq not in all_items:
+                 scores[i] = -1000
+
+     for b in range(B):
+         batch_seqs = predictions[b * k: (b + 1) * k]
+         batch_scores = scores[b * k: (b + 1) * k]
+
+         pairs = [(seq, score) for seq, score in zip(batch_seqs, batch_scores)]
+         sorted_pairs = sorted(pairs, key=lambda x: x[1], reverse=True)
+         target_item = targets[b]
+         one_results = []
+         for sorted_pred in sorted_pairs:
+             if sorted_pred[0] == target_item:
+                 one_results.append(1)
+             else:
+                 one_results.append(0)
+
+         results.append(one_results)
+
+     # result: [[0, 0, 0]]
+     return results
+
+ def get_metrics_results(topk_results, metrics):
+     res = {}
+     for m in metrics:
+         if m.lower().startswith("hit"):
+             k = int(m.split("@")[1])
+             res[m] = hit_k(topk_results, k)
+         elif m.lower().startswith("ndcg"):
+             k = int(m.split("@")[1])
+             res[m] = ndcg_k(topk_results, k)
+         else:
+             raise NotImplementedError
+
+     return res
+
+
+ def ndcg_k(topk_results, k):
+     # with a single ground-truth item per sequence, IDCG = 1, so the DCG sum
+     # equals NDCG; the caller divides the accumulated sum by the number of users
+     ndcg = 0.0
+     for row in topk_results:
+         res = row[:k]
+         one_ndcg = 0.0
+         for i in range(len(res)):
+             one_ndcg += res[i] / math.log(i + 2, 2)
+         ndcg += one_ndcg
+     return ndcg
+
+
+ def hit_k(topk_results, k):
+     hit = 0.0
+     for row in topk_results:
+         res = row[:k]
+         if sum(res) > 0:
+             hit += 1
+     return hit
+
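The metrics above can be exercised on a toy ranking. The compact re-implementations below mirror `hit_k` and `ndcg_k` (with one relevant item per user, DCG equals NDCG since IDCG = 1); they are a sketch for illustration, not the repo's code:

```python
import math

def hit_k(topk_results, k):
    # a user counts as a hit if the target appears in the top-k
    return sum(1.0 for row in topk_results if sum(row[:k]) > 0)

def ndcg_k(topk_results, k):
    # discounted gain of the (single) relevant item within the top-k
    return sum(
        sum(rel / math.log(i + 2, 2) for i, rel in enumerate(row[:k]))
        for row in topk_results
    )

# two users: the first hits at rank 2, the second misses entirely
topk_results = [[0, 1, 0, 0], [0, 0, 0, 0]]
print(hit_k(topk_results, 2))   # -> 1.0
print(ndcg_k(topk_results, 2))  # -> 1 / log2(3) ~= 0.6309
```

Dividing each sum by the number of users (as `evaluate-finetuned.py` does with `total`) yields the averaged Hit@k and NDCG@k reported in the paper.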
fine-tune.py ADDED
@@ -0,0 +1,154 @@
+ import argparse
+ import os
+ import sys
+ from typing import List
+
+ import torch
+ import transformers
+ from peft import PeftModel
+ from peft import (
+     TaskType,
+     LoraConfig,
+     get_peft_model,
+     get_peft_model_state_dict,
+     set_peft_model_state_dict,
+ )
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
+
+ from utils import *
+ from collator import Collator
+ from rq_llama import *
+
+ parser = argparse.ArgumentParser(description='rqllama-finetune')
+ parser = parse_finetune_args(parser)
+ args = parser.parse_args()
+
+ set_seed(args.seed)
+ ensure_dir(args.output_dir)
+
+ device_map = "auto"
+ world_size = int(os.environ.get("WORLD_SIZE", 1))
+ ddp = world_size != 1
+ local_rank = int(os.environ.get("LOCAL_RANK") or 0)
+ if local_rank == 0:
+     print(vars(args))
+
+ if ddp:
+     device_map = {"": local_rank}
+
+ train_data, valid_data = load_finetune_datasets(args)
+
+ rqllama = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
+ tokenizer = rqllama.tokenizer
+ # PeftModelForCausalLM
+ model = rqllama.model
+ device = rqllama.device
+
+ postfix = '<p-{}>'
+ new_tokens = []
+ new_ids = list(range(args.reindex))
+ for i in new_ids:
+     new_tokens.append(postfix.format(int(i)))
+ tokenizer.add_tokens(new_tokens)
+
+ if local_rank == 0:
+     print("token num:", len(rqllama.tokenizer))
+     print("data num:", len(train_data))
+
+ collator = Collator(args, tokenizer)
+
+ # Re-index Embedding: project the scalar re-index ids through a linear map
+ # to obtain initial embeddings for the new <p-k> tokens
+ new_ids = torch.tensor(new_ids, dtype=torch.float16).reshape(-1, 1)
+ re_index_emb = torch.nn.Linear(1, model.config.hidden_size, dtype=torch.float16).to(device)
+ new_embeddings = re_index_emb(new_ids.to(device))
+ # PeftModelForCausalLM -> LlamaForCausalLM -> LlamaModel
+ model.model.model.embed_tokens.original_module.weight.data = torch.cat([model.model.model.embed_tokens.original_module.weight.data, new_embeddings], dim=0)
+ model.model.model.embed_tokens.modules_to_save.default.weight.data = torch.cat([model.model.model.embed_tokens.modules_to_save.default.weight.data, new_embeddings], dim=0)
+
+ new_lm_head = torch.randn(args.reindex, model.config.hidden_size, requires_grad=True).to(device)
+ # print('new_lm_head:', new_lm_head.requires_grad)
+ # PeftModelForCausalLM -> LlamaForCausalLM
+ model.model.lm_head.original_module.weight.data = torch.cat([model.model.lm_head.original_module.weight.data, new_lm_head], dim=0)
+ model.model.lm_head.modules_to_save.default.weight.data = torch.cat([model.model.lm_head.modules_to_save.default.weight.data, new_lm_head], dim=0)
+
+ model.config.vocab_size = len(tokenizer)
+
+ # print(model.model.model.embed_tokens.original_module.weight.shape)
+ # print(len(tokenizer))
+
+ model.train()
+
+ if local_rank == 0:
+     model.print_trainable_parameters()
+
+ trainer = transformers.Trainer(
+     model=model,
+     train_dataset=train_data,
+     eval_dataset=valid_data,
+     args=transformers.TrainingArguments(
+         seed=args.seed,
+         per_device_train_batch_size=args.per_device_batch_size,
+         per_device_eval_batch_size=args.per_device_batch_size,
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         warmup_ratio=args.warmup_ratio,
+         num_train_epochs=args.epochs,
+         learning_rate=args.learning_rate,
+         weight_decay=args.weight_decay,
+         lr_scheduler_type=args.lr_scheduler_type,
+         fp16=args.fp16,
+         bf16=args.bf16,
+         logging_steps=args.logging_step,
+         optim=args.optim,
+         gradient_checkpointing=True,
+         evaluation_strategy=args.save_and_eval_strategy,
+         save_strategy=args.save_and_eval_strategy,
+         eval_steps=args.save_and_eval_steps,
+         save_steps=args.save_and_eval_steps,
+         output_dir=args.output_dir,
+         save_total_limit=50,
+         load_best_model_at_end=True,
+         deepspeed=args.deepspeed,
+         ddp_find_unused_parameters=False if ddp else None,
+         report_to=None,
+         eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
+         dataloader_num_workers=args.dataloader_num_workers,
+         dataloader_prefetch_factor=args.dataloader_prefetch_factor,
+         remove_unused_columns=args.remove_unused_columns,
+     ),
+     tokenizer=tokenizer,
+     data_collator=collator,
+ )
+ model.config.use_cache = False
+
+ if torch.__version__ >= "2" and sys.platform != "win32":
+     model = torch.compile(model)
+
+ trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
+
+ trainer.save_state()
+ trainer.save_model(output_dir=args.output_dir)
+
+ if local_rank == 0:
+     print('rqllama fine-tune finished.')
+
+     import smtplib
+     from email.mime.text import MIMEText
+     mail_host = 'smtp.qq.com'
+     # read the SMTP auth code from the environment instead of hard-coding a credential
+     mail_code = os.environ.get('MAIL_AUTH_CODE', '')
+     sender = '1849334588@qq.com'
+     receiver = 'esperanto1949@foxmail.com'
+
+     task = '[v39: finetune twin-tower]'
+     message = MIMEText('Task {task} Finished'.format(task=task), 'plain', 'utf-8')
+     message['Subject'] = 'Auto Email'
+     message['From'] = sender
+     message['To'] = receiver
+
+     server = smtplib.SMTP_SSL(mail_host, 465)
+     server.login(sender, mail_code)
+     server.sendmail(sender, receiver, message.as_string())
+
+     server.quit()
finetune.py ADDED
@@ -0,0 +1,119 @@
+ import argparse
+ import os
+
+ import sys
+ from typing import List
+
+ import torch
+ import transformers
+
+ from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig
+
+ from utils import *
+ from collator import Collator
+
+ def train(args):
+
+     set_seed(args.seed)
+     ensure_dir(args.output_dir)
+
+     device_map = "auto"
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+     ddp = world_size != 1
+     local_rank = int(os.environ.get("LOCAL_RANK") or 0)
+     if local_rank == 0:
+         print(vars(args))
+
+     if ddp:
+         device_map = {"": local_rank}
+
+     config = LlamaConfig.from_pretrained(args.base_model)
+     tokenizer = LlamaTokenizer.from_pretrained(
+         args.base_model,
+         model_max_length=args.model_max_length,
+         padding_side="right",
+     )
+     tokenizer.pad_token_id = 0
+     gradient_checkpointing = True
+
+     train_data, valid_data = load_datasets(args)
+     add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
+     config.vocab_size = len(tokenizer)
+     if local_rank == 0:
+         print("add {} new token.".format(add_num))
+         print("data num:", len(train_data))
+         tokenizer.save_pretrained(args.output_dir)
+         config.save_pretrained(args.output_dir)
+
+     collator = Collator(args, tokenizer)
+
+     model = LlamaForCausalLM.from_pretrained(
+         args.base_model,
+         # torch_dtype=torch.float16,
+         device_map=device_map,
+     )
+     model.resize_token_embeddings(len(tokenizer))
+
+     if not ddp and torch.cuda.device_count() > 1:
+         model.is_parallelizable = True
+         model.model_parallel = True
+
+     trainer = transformers.Trainer(
+         model=model,
+         train_dataset=train_data,
+         eval_dataset=valid_data,
+         args=transformers.TrainingArguments(
+             seed=args.seed,
+             per_device_train_batch_size=args.per_device_batch_size,
+             per_device_eval_batch_size=args.per_device_batch_size,
+             gradient_accumulation_steps=args.gradient_accumulation_steps,
+             warmup_ratio=args.warmup_ratio,
+             num_train_epochs=args.epochs,
+             learning_rate=args.learning_rate,
+             weight_decay=args.weight_decay,
+             lr_scheduler_type=args.lr_scheduler_type,
+             fp16=args.fp16,
+             bf16=args.bf16,
+             logging_steps=args.logging_step,
+             optim=args.optim,
+             gradient_checkpointing=gradient_checkpointing,
+             evaluation_strategy=args.save_and_eval_strategy,
+             save_strategy=args.save_and_eval_strategy,
+             eval_steps=args.save_and_eval_steps,
+             save_steps=args.save_and_eval_steps,
+             output_dir=args.output_dir,
+             save_total_limit=5,
+             load_best_model_at_end=True,
+             deepspeed=args.deepspeed,
+             ddp_find_unused_parameters=False if ddp else None,
+             report_to=None,
+             eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
+         ),
+         tokenizer=tokenizer,
+         data_collator=collator,
+     )
+     model.config.use_cache = False
+
+     trainer.train(
+         resume_from_checkpoint=args.resume_from_checkpoint,
+     )
+
+     trainer.save_state()
+     trainer.save_model(output_dir=args.output_dir)
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description='LLMRec')
+     parser = parse_global_args(parser)
+     parser = parse_train_args(parser)
+     parser = parse_dataset_args(parser)
+
+     args = parser.parse_args()
+
+     train(args)
generate_embeddings.py ADDED
@@ -0,0 +1,86 @@
+ import os
+ import collections
+ import json
+ import logging
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import torch
+ from time import time
+ from torch import optim
+ from tqdm import tqdm
+ from torch.utils.data import DataLoader
+ from rq_llama import *
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Index")
+     parser.add_argument("--ckpt_path", type=str, default="", help="")
+     parser.add_argument("--item_save_path", type=str, default="", help="")
+     parser.add_argument("--user_save_path", type=str, default="", help="")
+     parser.add_argument("--device_map", type=str, default="1", help="gpu or cpu")
+     return parser.parse_args()
+
+ args = parse_args()
+ print(args)
+ device_map = {'': int(args.device_map)}
+ MODEL = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
+ MODEL.eval()
+ device = MODEL.device
+ llama = MODEL.model.get_decoder()
+ tokenizer = MODEL.tokenizer
+ item_texts = MODEL.item_texts
+ user_texts = MODEL.user_texts
+
+ all_idx = []
+ all_embeddings = []
+ with torch.no_grad():
+     for idx, text in tqdm(item_texts.items()):
+         item_text = text['title'] + ' ' + text['description']
+         item_ids = tokenizer(item_text, return_tensors='pt', padding=True, truncation=True).to(device)
+         item_emb = llama(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
+         # mask-aware mean pooling over the last hidden states
+         item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
+         item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim=-1, keepdim=True)
+
+         all_idx.append(idx)
+         all_embeddings.append(item_emb.detach().cpu().numpy().flatten().tolist())
+
+ results = {
+     'id': all_idx,
+     'emb': []
+ }
+
+ for emb in tqdm(all_embeddings):
+     results['emb'].append(' '.join(str(e) for e in emb))
+
+ df = pd.DataFrame(results)
+ df.to_csv(args.item_save_path, sep='\t', header=False, index=False)
+
+ all_idx = []
+ all_embeddings = []
+ with torch.no_grad():
+     for idx, text in tqdm(user_texts.items()):
+         user_text = ' '.join(text)
+         user_ids = tokenizer(user_text, return_tensors='pt', padding=True, truncation=True).to(device)
+         user_emb = llama(input_ids=user_ids.input_ids, attention_mask=user_ids.attention_mask)
+         user_emb = user_emb.last_hidden_state * user_ids.attention_mask.unsqueeze(-1)
+         user_emb = user_emb.sum(dim=1) / user_ids.attention_mask.sum(dim=-1, keepdim=True)
+
+         all_idx.append(idx)
+         all_embeddings.append(user_emb.detach().cpu().numpy().flatten().tolist())
+
+ results = {
+     'id': all_idx,
+     'emb': []
+ }
+
+ for emb in tqdm(all_embeddings):
+     results['emb'].append(' '.join(str(e) for e in emb))
+
+ df = pd.DataFrame(results)
+ df.to_csv(args.user_save_path, sep='\t', header=False, index=False)
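`generate_embeddings.py` derives one vector per item (or user) by mask-aware mean pooling over the decoder's last hidden states: padded positions are zeroed out, then each sequence is averaged over its real tokens only. A NumPy sketch of that pooling step (shapes and values are illustrative):

```python
import numpy as np

def masked_mean_pool(hidden, mask):
    # hidden: (batch, seq_len, dim); mask: (batch, seq_len) of 0/1
    # zero out padded positions, then average over the real token count
    masked = hidden * mask[..., None]
    return masked.sum(axis=1) / mask.sum(axis=1, keepdims=True)

hidden = np.array([[[1.0, 2.0],
                    [3.0, 4.0],
                    [9.0, 9.0]]])       # last position is padding
mask = np.array([[1, 1, 0]])
pooled = masked_mean_pool(hidden, mask)
print(pooled)                            # -> [[2. 3.]]
```

Without the mask, the padding row `[9, 9]` would distort the average; masking reproduces the `last_hidden_state * attention_mask` / `attention_mask.sum` arithmetic in the script.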
generate_indices.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import collections
3
+ import json
4
+ import logging
5
+ import argparse
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from time import time
10
+ from torch import optim
11
+ from tqdm import tqdm
12
+ from torch.utils.data import DataLoader
13
+
14
+ from rq_llama import *
15
+ from index.datasets import EmbDataset
16
+
17
+ def if_collided(all_indices_str):
18
+ tot_item = len(all_indices_str)
19
+ tot_indice = len(set(all_indices_str.tolist()))
20
+ return tot_item == tot_indice
21
+
22
+ def get_indices_count(all_indices_str):
23
+ indices_count = collections.defaultdict(int)
24
+ for index in all_indices_str:
25
+ indices_count[index] += 1
26
+ return indices_count
27
+
28
+ def get_collision_item(all_indices_str):
29
+ index2id = {}
30
+ for i, index in enumerate(all_indices_str):
31
+ if index not in index2id:
32
+ index2id[index] = []
33
+ index2id[index].append(i)
34
+ collision_item_groups = []
35
+ for index in index2id:
36
+ if len(index2id[index]) > 1:
37
+ collision_item_groups.append(index2id[index])
38
+ return collision_item_groups
39
+
40
+ def parse_args():
41
+ parser = argparse.ArgumentParser(description = "Index")
42
+ parser.add_argument("--ckpt_path", type = str, default = "", help = "")
43
+ parser.add_argument("--item_data_path", type = str, default = "", help = "")
44
+ parser.add_argument("--user_data_path", type = str, default = "", help = "")
45
+ parser.add_argument("--save_path", type = str, default = "", help = "")
46
+ parser.add_argument("--device_map", type = str, default = "1", help = "gpu or cpu")
47
+ return parser.parse_args()
48
+
49
+ args = parse_args()
50
+ print(args)
51
+
52
+ device_map = {'': int(args.device_map)}
53
+ MODEL = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype = torch.float16, low_cpu_mem_usage = True, device_map = device_map)
54
+ MODEL.eval()
55
+ device = MODEL.device
56
+ postfix = '<p-{}>'
57
+
58
+ data = EmbDataset(args.item_data_path)
59
+ data_loader = DataLoader(data, num_workers = 4, batch_size = 64, shuffle = False, pin_memory = True)
60
+ rqvae = MODEL.item_rqvae
61
+ prefix = MODEL.prefix
62
+
63
+ index_table = {}
64
+ all_indices = []
65
+ all_indices_str = []
66
+ with torch.no_grad():
67
+ for x in tqdm(data_loader):
68
+ indices = rqvae.get_indices(x.to(device), False)
69
+ indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
70
+ for index in indices:
71
+ code = []
72
+ for i, ind in enumerate(index):
73
+ code.append(prefix[i].format(int(ind)))
74
+
75
+ if str(code) in index_table:
76
+ index_table[str(code)] += 1
77
+ else:
78
+ index_table[str(code)] = 0
79
+ code.append(postfix.format(index_table[str(code)]))
80
+
81
+ all_indices.append(code)
82
+ all_indices_str.append(str(code))
83
+
84
+ all_indices = np.array(all_indices)
85
+ all_indices_str = np.array(all_indices_str)
86
+
87
+ print("All indices number: ", len(all_indices))
88
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
89
+ print('Re-index number:', max(index_table.values()))
90
+
91
+ all_indices_dict = {}
92
+ for item, indices in enumerate(all_indices.tolist()):
93
+ all_indices_dict[item] = list(indices)
94
+
95
+ reindex_dict = {'reindex': max(index_table.values())}
96
+
97
+ json_path = os.path.join(args.save_path,'indices.item.json')
98
+ with open(json_path, 'w',encoding = 'utf-8') as f:
99
+ json.dump(all_indices_dict, f)
100
+
101
+ reindex_path = os.path.join(args.save_path,'reindex.item.json')
102
+ with open(reindex_path, 'w',encoding = 'utf-8') as f:
103
+ json.dump(reindex_dict, f)
104
+
105
+ data = EmbDataset(args.user_data_path)
106
+ data_loader = DataLoader(data, num_workers = 4, batch_size = 64, shuffle = False, pin_memory = True)
107
+ rqvae = MODEL.user_rqvae
108
+ prefix = MODEL.user_prefix
109
+
110
+ # index_table = {}
111
+ all_indices = []
112
+ all_indices_str = []
113
+ with torch.no_grad():
114
+ for x in tqdm(data_loader):
115
+ indices = rqvae.get_indices(x.to(device), False)
116
+ indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
117
+ for index in indices:
118
+ code = []
119
+ for i, ind in enumerate(index):
120
+ code.append(prefix[i].format(int(ind)))
121
+
122
+ # if str(code) in index_table:
123
+ # index_table[str(code)] += 1
124
+ # else:
125
+ # index_table[str(code)] = 0
126
+ # code.append(postfix.format(index_table[str(code)]))
127
+
128
+ all_indices.append(code)
129
+ all_indices_str.append(str(code))
130
+
131
+ all_indices = np.array(all_indices)
132
+ all_indices_str = np.array(all_indices_str)
133
+
134
+ print("All indices number: ", len(all_indices))
135
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
136
+ # print('Re-index number:', max(index_table.values()))
137
+
138
+ all_indices_dict = {}
139
+ for item, indices in enumerate(all_indices.tolist()):
140
+ all_indices_dict[item] = list(indices)
141
+
142
+ # reindex_dict = {'reindex': max(index_table.values())}
143
+
144
+ json_path = os.path.join(args.save_path,'indices.user.json')
145
+ with open(json_path, 'w',encoding = 'utf-8') as f:
146
+ json.dump(all_indices_dict, f)
147
+
148
+ # reindex_path = os.path.join(args.save_path,'reindex.user.json')
149
+ # with open(reindex_path, 'w',encoding = 'utf-8') as f:
150
+ # json.dump(reindex_dict, f)
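The scripts in this commit repeatedly check code uniqueness by string-serializing each multi-level code and counting duplicates with `get_indices_count`. A minimal standalone sketch of that check, using hypothetical codes and no project dependencies:

```python
import collections

def get_indices_count(all_indices_str):
    # count how many items map to each serialized code string
    indices_count = collections.defaultdict(int)
    for index in all_indices_str:
        indices_count[index] += 1
    return indices_count

# two items share the first code, so the max conflict count is 2
codes = ["['<a_1>', '<b_2>']", "['<a_1>', '<b_2>']", "['<a_3>', '<b_4>']"]
counts = get_indices_count(codes)
print(max(counts.values()))  # 2
```

A "Max number of conflicts" of 1 therefore means every item received a unique index.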
generate_random_indices.py ADDED
@@ -0,0 +1,198 @@
+ import os
+ import collections
+ import json
+ import logging
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import torch
+ from time import time
+ from torch import optim
+ from tqdm import tqdm
+ import torch.utils.data as data
+ from torch.utils.data import DataLoader
+ from index.models.rqvae import RQVAE
+ # from rq_llama import *
+ # from index.datasets import EmbDataset
+ import random
+
+ class NpyDataset(data.Dataset):
+     def __init__(self, data_path):
+         self.data_path = data_path
+         self.embeddings = np.load(data_path)
+         self.dim = self.embeddings.shape[-1]
+
+     def __getitem__(self, index):
+         emb = self.embeddings[index]
+         tensor_emb = torch.FloatTensor(emb)
+         return tensor_emb
+
+     def __len__(self):
+         return len(self.embeddings)
+
+ def if_collided(all_indices_str):
+     tot_item = len(all_indices_str)
+     tot_indice = len(set(all_indices_str.tolist()))
+     return tot_item == tot_indice
+
+ def get_indices_count(all_indices_str):
+     indices_count = collections.defaultdict(int)
+     for index in all_indices_str:
+         indices_count[index] += 1
+     return indices_count
+
+ def get_collision_item(all_indices_str):
+     index2id = {}
+     for i, index in enumerate(all_indices_str):
+         if index not in index2id:
+             index2id[index] = []
+         index2id[index].append(i)
+     collision_item_groups = []
+     for index in index2id:
+         if len(index2id[index]) > 1:
+             collision_item_groups.append(index2id[index])
+     return collision_item_groups
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Index")
+     parser.add_argument("--item_model_path", type=str, default="", help="")
+     parser.add_argument("--item_data_path", type=str, default="", help="")
+     parser.add_argument("--user_model_path", type=str, default="", help="")
+     parser.add_argument("--user_data_path", type=str, default="", help="")
+     # parser.add_argument("--save_path", type=str, default="", help="")
+     parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")
+     return parser.parse_args()
+
+ generate_args = parse_args()
+ print(generate_args)
+
+ device = torch.device(generate_args.device)
+
+ # generate item index
+ ckpt = torch.load(os.path.join(generate_args.item_model_path, 'best_collision_model.pth'), map_location=torch.device('cpu'))
+ args = ckpt['args']
+ state_dict = ckpt['state_dict']
+
+ data = NpyDataset(generate_args.item_data_path)
+ data_loader = DataLoader(data, num_workers=args.num_workers, batch_size=64, shuffle=False, pin_memory=True)
+ # model = RQVAE(
+ #     in_dim=data.dim,
+ #     num_emb_list=args.num_emb_list,
+ #     e_dim=args.e_dim,
+ #     layers=args.layers,
+ #     dropout_prob=args.dropout_prob,
+ #     bn=args.bn,
+ #     loss_type=args.loss_type,
+ #     quant_loss_weight=args.quant_loss_weight,
+ #     kmeans_init=args.kmeans_init,
+ #     kmeans_iters=args.kmeans_iters,
+ #     sk_epsilons=args.sk_epsilons,
+ #     sk_iters=args.sk_iters,
+ # )
+ # model.load_state_dict(state_dict)
+ # model = model.to(device)
+ # model.eval()
+ # print(model)
+
+ prefix = ["<a_{}>", "<b_{}>", "<c_{}>", "<d_{}>", "<e_{}>"]
+ postfix = "<p_{}>"
+
+ index_table = {}
+ all_indices = []
+ all_indices_str = []
+ with torch.no_grad():
+     for x in tqdm(data_loader):
+         # indices = model.get_indices(x.to(device), False)
+         # indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+
+         # draw one random 4-level code per item; use the actual batch size so the
+         # final (possibly smaller) batch does not produce extra codes
+         indices = np.random.randint(0, 256, size=(x.shape[0], 4), dtype=int)
+         for index in indices:
+             code = []
+             for i, ind in enumerate(index):
+                 code.append(prefix[i].format(int(ind)))
+
+             if str(code) in index_table:
+                 index_table[str(code)] += 1
+             else:
+                 index_table[str(code)] = 0
+             code.append(postfix.format(index_table[str(code)]))
+
+             all_indices.append(code)
+             all_indices_str.append(str(code))
+
+ all_indices = np.array(all_indices)
+ all_indices_str = np.array(all_indices_str)
+
+ print("All indices number: ", len(all_indices))
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
+ print('Re-index number:', max(index_table.values()))
+
+ all_indices_dict = {}
+ for item, indices in enumerate(all_indices.tolist()):
+     all_indices_dict[item] = list(indices)
+ reindex_dict = {'reindex': max(index_table.values())}
+
+ item_index_path = os.path.join(generate_args.item_model_path, 'indices.random.item.json')
+ with open(item_index_path, 'w', encoding='utf-8') as f:
+     json.dump(all_indices_dict, f)
+
+ item_reindex_path = os.path.join(generate_args.item_model_path, 'reindex.random.item.json')
+ with open(item_reindex_path, 'w', encoding='utf-8') as f:
+     json.dump(reindex_dict, f)
+
+ # generate user index
+ ckpt = torch.load(os.path.join(generate_args.user_model_path, 'best_collision_model.pth'), map_location=torch.device('cpu'))
+ args = ckpt['args']
+ state_dict = ckpt['state_dict']
+
+ data = NpyDataset(generate_args.user_data_path)
+ data_loader = DataLoader(data, num_workers=args.num_workers, batch_size=64, shuffle=False, pin_memory=True)
+ # model = RQVAE(
+ #     in_dim=data.dim,
+ #     num_emb_list=args.num_emb_list,
+ #     e_dim=args.e_dim,
+ #     layers=args.layers,
+ #     dropout_prob=args.dropout_prob,
+ #     bn=args.bn,
+ #     loss_type=args.loss_type,
+ #     quant_loss_weight=args.quant_loss_weight,
+ #     kmeans_init=args.kmeans_init,
+ #     kmeans_iters=args.kmeans_iters,
+ #     sk_epsilons=args.sk_epsilons,
+ #     sk_iters=args.sk_iters,
+ # )
+ # model.load_state_dict(state_dict)
+ # model = model.to(device)
+ # model.eval()
+ # print(model)
+
+ prefix = ['<z-{}>', '<y-{}>', '<x-{}>', '<w-{}>', '<v-{}>']
+
+ all_indices = []
+ all_indices_str = []
+ with torch.no_grad():
+     for x in tqdm(data_loader):
+         # indices = rqvae.get_indices(x.to(device), False)
+         # indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+         indices = np.random.randint(0, 256, size=(x.shape[0], 4), dtype=int)
+         for index in indices:
+             code = []
+             for i, ind in enumerate(index):
+                 code.append(prefix[i].format(int(ind)))
+
+             all_indices.append(code)
+             all_indices_str.append(str(code))
+
+ all_indices = np.array(all_indices)
+ all_indices_str = np.array(all_indices_str)
+
+ print("All indices number: ", len(all_indices))
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
+
+ all_indices_dict = {}
+ for item, indices in enumerate(all_indices.tolist()):
+     all_indices_dict[item] = list(indices)
+
+ json_path = os.path.join(generate_args.user_model_path, 'indices.random.user.json')
+ with open(json_path, 'w', encoding='utf-8') as f:
+     json.dump(all_indices_dict, f)
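`generate_random_indices.py` disambiguates colliding item codes with the `<p_{}>` postfix token: the k-th item that hits an already-seen base code gets an extra `<p_k>`. A self-contained illustration of that scheme, with hypothetical codes:

```python
# index_table counts, per base code, how many items have already claimed it
index_table = {}
postfix = "<p_{}>"

def disambiguate(code):
    # key on the base code *before* the postfix is appended, as in the script
    key = str(code)
    if key in index_table:
        index_table[key] += 1
    else:
        index_table[key] = 0
    return code + [postfix.format(index_table[key])]

a = disambiguate(["<a_5>", "<b_9>"])
b = disambiguate(["<a_5>", "<b_9>"])  # same base code as `a`
print(a[-1], b[-1])  # <p_0> <p_1>
```

After this step, `max(index_table.values())` is the largest number of re-indexed duplicates for any base code, which is what the script saves as `reindex_dict`.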
generate_static_indices.py ADDED
@@ -0,0 +1,193 @@
+ import os
+ import collections
+ import json
+ import logging
+ import argparse
+ import numpy as np
+ import pandas as pd
+ import torch
+ from time import time
+ from torch import optim
+ from tqdm import tqdm
+ import torch.utils.data as data
+ from torch.utils.data import DataLoader
+ from index.models.rqvae import RQVAE
+ # from rq_llama import *
+ # from index.datasets import EmbDataset
+
+ class NpyDataset(data.Dataset):
+     def __init__(self, data_path):
+         self.data_path = data_path
+         self.embeddings = np.load(data_path)
+         self.dim = self.embeddings.shape[-1]
+
+     def __getitem__(self, index):
+         emb = self.embeddings[index]
+         tensor_emb = torch.FloatTensor(emb)
+         return tensor_emb
+
+     def __len__(self):
+         return len(self.embeddings)
+
+ def if_collided(all_indices_str):
+     tot_item = len(all_indices_str)
+     tot_indice = len(set(all_indices_str.tolist()))
+     return tot_item == tot_indice
+
+ def get_indices_count(all_indices_str):
+     indices_count = collections.defaultdict(int)
+     for index in all_indices_str:
+         indices_count[index] += 1
+     return indices_count
+
+ def get_collision_item(all_indices_str):
+     index2id = {}
+     for i, index in enumerate(all_indices_str):
+         if index not in index2id:
+             index2id[index] = []
+         index2id[index].append(i)
+     collision_item_groups = []
+     for index in index2id:
+         if len(index2id[index]) > 1:
+             collision_item_groups.append(index2id[index])
+     return collision_item_groups
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Index")
+     parser.add_argument("--item_model_path", type=str, default="", help="")
+     parser.add_argument("--item_data_path", type=str, default="", help="")
+     parser.add_argument("--user_model_path", type=str, default="", help="")
+     parser.add_argument("--user_data_path", type=str, default="", help="")
+     # parser.add_argument("--save_path", type=str, default="", help="")
+     parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")
+     return parser.parse_args()
+
+ generate_args = parse_args()
+ print(generate_args)
+
+ device = torch.device(generate_args.device)
+
+ # generate item index
+ ckpt = torch.load(generate_args.item_model_path, map_location=torch.device('cpu'))
+ args = ckpt['args']
+ state_dict = ckpt['state_dict']
+
+ data = NpyDataset(generate_args.item_data_path)
+ data_loader = DataLoader(data, num_workers=args.num_workers, batch_size=64, shuffle=False, pin_memory=True)
+ model = RQVAE(
+     in_dim=data.dim,
+     num_emb_list=args.num_emb_list,
+     e_dim=args.e_dim,
+     layers=args.layers,
+     dropout_prob=args.dropout_prob,
+     bn=args.bn,
+     loss_type=args.loss_type,
+     quant_loss_weight=args.quant_loss_weight,
+     kmeans_init=args.kmeans_init,
+     kmeans_iters=args.kmeans_iters,
+     sk_epsilons=args.sk_epsilons,
+     sk_iters=args.sk_iters,
+ )
+ model.load_state_dict(state_dict)
+ model = model.to(device)
+ model.eval()
+ # print(model)
+
+ prefix = ["<a_{}>", "<b_{}>", "<c_{}>", "<d_{}>", "<e_{}>"]
+ postfix = "<p_{}>"
+
+ index_table = {}
+ all_indices = []
+ all_indices_str = []
+ with torch.no_grad():
+     for x in tqdm(data_loader):
+         indices = model.get_indices(x.to(device), False)
+         indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+         for index in indices:
+             code = []
+             for i, ind in enumerate(index):
+                 code.append(prefix[i].format(int(ind)))
+
+             if str(code) in index_table:
+                 index_table[str(code)] += 1
+             else:
+                 index_table[str(code)] = 0
+             code.append(postfix.format(index_table[str(code)]))
+
+             all_indices.append(code)
+             all_indices_str.append(str(code))
+
+ all_indices = np.array(all_indices)
+ all_indices_str = np.array(all_indices_str)
+
+ print("All indices number: ", len(all_indices))
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
+ print('Re-index number:', max(index_table.values()))
+
+ all_indices_dict = {}
+ for item, indices in enumerate(all_indices.tolist()):
+     all_indices_dict[item] = list(indices)
+ reindex_dict = {'reindex': max(index_table.values())}
+
+ # --item_model_path points at a checkpoint file here (torch.load above), so save next to it
+ item_index_path = os.path.join(os.path.dirname(generate_args.item_model_path), 'indices.item.json')
+ with open(item_index_path, 'w', encoding='utf-8') as f:
+     json.dump(all_indices_dict, f)
+
+ item_reindex_path = os.path.join(os.path.dirname(generate_args.item_model_path), 'reindex.item.json')
+ with open(item_reindex_path, 'w', encoding='utf-8') as f:
+     json.dump(reindex_dict, f)
+
+ # generate user index
+ ckpt = torch.load(generate_args.user_model_path, map_location=torch.device('cpu'))
+ args = ckpt['args']
+ state_dict = ckpt['state_dict']
+
+ data = NpyDataset(generate_args.user_data_path)
+ data_loader = DataLoader(data, num_workers=args.num_workers, batch_size=64, shuffle=False, pin_memory=True)
+ model = RQVAE(
+     in_dim=data.dim,
+     num_emb_list=args.num_emb_list,
+     e_dim=args.e_dim,
+     layers=args.layers,
+     dropout_prob=args.dropout_prob,
+     bn=args.bn,
+     loss_type=args.loss_type,
+     quant_loss_weight=args.quant_loss_weight,
+     kmeans_init=args.kmeans_init,
+     kmeans_iters=args.kmeans_iters,
+     sk_epsilons=args.sk_epsilons,
+     sk_iters=args.sk_iters,
+ )
+ model.load_state_dict(state_dict)
+ model = model.to(device)
+ model.eval()
+ # print(model)
+
+ prefix = ['<z-{}>', '<y-{}>', '<x-{}>', '<w-{}>', '<v-{}>']
+
+ all_indices = []
+ all_indices_str = []
+ with torch.no_grad():
+     for x in tqdm(data_loader):
+         indices = model.get_indices(x.to(device), False)
+         indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+         for index in indices:
+             code = []
+             for i, ind in enumerate(index):
+                 code.append(prefix[i].format(int(ind)))
+
+             all_indices.append(code)
+             all_indices_str.append(str(code))
+
+ all_indices = np.array(all_indices)
+ all_indices_str = np.array(all_indices_str)
+
+ print("All indices number: ", len(all_indices))
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
+
+ all_indices_dict = {}
+ for item, indices in enumerate(all_indices.tolist()):
+     all_indices_dict[item] = list(indices)
+
+ json_path = os.path.join(os.path.dirname(generate_args.user_model_path), 'indices.user.json')
+ with open(json_path, 'w', encoding='utf-8') as f:
+     json.dump(all_indices_dict, f)
index/datasets.py ADDED
@@ -0,0 +1,27 @@
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+ import pandas as pd
+ from tqdm import tqdm
+
+ class EmbDataset(data.Dataset):
+     def __init__(self, data_path):
+         self.data_path = data_path
+         names = ['emb']
+         usecols = [1]
+         tsv_data = pd.read_csv(data_path, sep='\t', usecols=usecols, names=names, quotechar=None, quoting=3)
+         features = tsv_data['emb'].values.tolist()
+         num_data = len(features)
+         for i in tqdm(range(num_data)):
+             features[i] = [float(s) for s in features[i].split(' ')]
+         self.embeddings = np.array(features, dtype=np.float16)
+         assert self.embeddings.shape[0] == num_data
+         self.dim = self.embeddings.shape[-1]
+
+     def __getitem__(self, index):
+         emb = self.embeddings[index]
+         tensor_emb = torch.tensor(emb, dtype=torch.float16)
+         return tensor_emb
+
+     def __len__(self):
+         return len(self.embeddings)
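`EmbDataset` expects a headerless TSV whose second column holds each embedding as space-separated floats. A small sketch of that parsing without pandas, on hypothetical in-memory rows:

```python
import numpy as np

# two hypothetical rows in the same layout EmbDataset reads:
# <id>\t<space-separated floats>
rows = ["0\t0.5 1.0 -2.0", "1\t3.0 0.25 4.0"]
features = [[float(s) for s in line.split('\t')[1].split(' ')] for line in rows]
embeddings = np.array(features, dtype=np.float16)  # same half-precision storage as the class
print(embeddings.shape)  # (2, 3)
```

Note that storing embeddings as `float16` halves memory at some precision cost; the dataset also returns `float16` tensors, so the model consuming them must accept that dtype.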
index/generate_indices.py ADDED
@@ -0,0 +1,151 @@
+ import collections
+ import json
+ import logging
+ import argparse
+
+ import numpy as np
+ import torch
+ from time import time
+ from torch import optim
+ from tqdm import tqdm
+
+ from torch.utils.data import DataLoader
+
+ from datasets import EmbDataset
+ from models.rqvae import RQVAE
+
+ import os
+
+ def check_collision(all_indices_str):
+     tot_item = len(all_indices_str)
+     tot_indice = len(set(all_indices_str.tolist()))
+     return tot_item == tot_indice
+
+ def get_indices_count(all_indices_str):
+     indices_count = collections.defaultdict(int)
+     for index in all_indices_str:
+         indices_count[index] += 1
+     return indices_count
+
+ def get_collision_item(all_indices_str):
+     index2id = {}
+     for i, index in enumerate(all_indices_str):
+         if index not in index2id:
+             index2id[index] = []
+         index2id[index].append(i)
+
+     collision_item_groups = []
+
+     for index in index2id:
+         if len(index2id[index]) > 1:
+             collision_item_groups.append(index2id[index])
+
+     return collision_item_groups
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Index")
+
+     parser.add_argument("--data_path", type=str, default="", help="Infer data path.")
+     parser.add_argument("--ckpt_path", type=str, default="", help="model checkpoint for infer")
+     parser.add_argument("--id_save_path", type=str, default="", help="output path for id result")
+     parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")
+
+     return parser.parse_args()
+
+ # dataset = "Games"
+ # ckpt_path = "/zhengbowen/rqvae_ckpt/xxxx"
+ # output_dir = f"/zhengbowen/data/{dataset}/"
+ # output_file = f"{dataset}.index.json"
+ # output_file = os.path.join(output_dir, output_file)
+
+ infer_args = parse_args()
+ print('infer_args:', infer_args)
+ device = torch.device(infer_args.device)
+ output_file = infer_args.id_save_path
+ data = EmbDataset(infer_args.data_path)
+
+ ckpt = torch.load(infer_args.ckpt_path, map_location=torch.device('cpu'))
+ args = ckpt["args"]
+ state_dict = ckpt["state_dict"]
+
+ model = RQVAE(in_dim=data.dim,
+               num_emb_list=args.num_emb_list,
+               e_dim=args.e_dim,
+               layers=args.layers,
+               dropout_prob=args.dropout_prob,
+               bn=args.bn,
+               loss_type=args.loss_type,
+               quant_loss_weight=args.quant_loss_weight,
+               kmeans_init=args.kmeans_init,
+               kmeans_iters=args.kmeans_iters,
+               sk_epsilons=args.sk_epsilons,
+               sk_iters=args.sk_iters,
+               )
+
+ model.load_state_dict(state_dict)
+ model = model.to(device)
+ model.eval()
+ print(model)
+
+ data_loader = DataLoader(data, num_workers=args.num_workers, batch_size=64, shuffle=False, pin_memory=True)
+
+ all_indices = []
+ all_indices_str = []
+ prefix = ["<a-{}>", "<b-{}>", "<c-{}>", "<d-{}>", "<e-{}>"]
+
+ for d in tqdm(data_loader):
+     d = d.to(device)
+     indices = model.get_indices(d, use_sk=False)
+     indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+     for index in indices:
+         code = []
+         for i, ind in enumerate(index):
+             code.append(prefix[i].format(int(ind)))
+
+         all_indices.append(code)
+         all_indices_str.append(str(code))
+
+ all_indices = np.array(all_indices)
+ all_indices_str = np.array(all_indices_str)
+
+ # enable Sinkhorn only on the last quantizer level when resolving collisions
+ for vq in model.rq.vq_layers[:-1]:
+     vq.sk_epsilon = 0.0
+ if model.rq.vq_layers[-1].sk_epsilon == 0.0:
+     model.rq.vq_layers[-1].sk_epsilon = 0.003
+
+ tt = 0
+ # There are often duplicate items in the dataset; after 20 rounds we stop trying to differentiate them.
+ while True:
+     if tt >= 20 or check_collision(all_indices_str):
+         break
+
+     collision_item_groups = get_collision_item(all_indices_str)
+     # print(collision_item_groups)
+     print(len(collision_item_groups))
+     for collision_items in collision_item_groups:
+         d = data[collision_items].to(device)
+
+         indices = model.get_indices(d, use_sk=True)
+         indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
+         for item, index in zip(collision_items, indices):
+             code = []
+             for i, ind in enumerate(index):
+                 code.append(prefix[i].format(int(ind)))
+
+             all_indices[item] = code
+             all_indices_str[item] = str(code)
+     tt += 1
+
+ print("All indices number: ", len(all_indices))
+ print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values()))
+
+ tot_item = len(all_indices_str)
+ tot_indice = len(set(all_indices_str.tolist()))
+ print("Collision Rate", (tot_item - tot_indice) / tot_item)
+
+ all_indices_dict = {}
+ for item, indices in enumerate(all_indices.tolist()):
+     all_indices_dict[item] = list(indices)
+
+ with open(output_file, 'w') as fp:
+     json.dump(all_indices_dict, fp)
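The collision-resolution loop in `generate_indices.py` only re-encodes the *groups* of items that share a code, found by `get_collision_item`. A compact standalone version of that grouping, on hypothetical serialized codes:

```python
def get_collision_item(all_indices_str):
    # map each serialized code to the list of item ids that received it
    index2id = {}
    for i, index in enumerate(all_indices_str):
        index2id.setdefault(index, []).append(i)
    # keep only codes claimed by more than one item
    return [ids for ids in index2id.values() if len(ids) > 1]

groups = get_collision_item(["c1", "c2", "c1", "c3", "c2"])
print(groups)  # [[0, 2], [1, 4]]
```

Only these groups are re-quantized with `use_sk=True`, so the expensive Sinkhorn-based re-assignment touches a small fraction of the catalog per round.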
index/main.py ADDED
@@ -0,0 +1,87 @@
+ import argparse
+ import random
+ import torch
+ import numpy as np
+ from time import time
+ import logging
+
+ from torch.utils.data import DataLoader
+
+ from datasets import EmbDataset
+ from models.rqvae import RQVAE
+ from trainer import Trainer
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Index")
+
+     parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
+     parser.add_argument('--epochs', type=int, default=5000, help='number of epochs')
+     parser.add_argument('--batch_size', type=int, default=1024, help='batch size')
+     parser.add_argument('--num_workers', type=int, default=4)
+     parser.add_argument('--eval_step', type=int, default=50, help='eval step')
+     parser.add_argument('--learner', type=str, default="AdamW", help='optimizer')
+     parser.add_argument("--data_path", type=str,
+                         default="../data/Games/Games.emb-llama-td.npy",
+                         help="Input data path.")
+
+     parser.add_argument('--weight_decay', type=float, default=1e-4, help='l2 regularization weight')
+     parser.add_argument("--dropout_prob", type=float, default=0.0, help="dropout ratio")
+     parser.add_argument("--bn", type=bool, default=False, help="use bn or not")
+     parser.add_argument("--loss_type", type=str, default="mse", help="loss_type")
+     parser.add_argument("--kmeans_init", type=bool, default=True, help="use kmeans_init or not")
+     parser.add_argument("--kmeans_iters", type=int, default=100, help="max kmeans iters")
+     parser.add_argument('--sk_epsilons', type=float, nargs='+', default=[0.0, 0.0, 0.0], help="sinkhorn epsilons")
+     parser.add_argument("--sk_iters", type=int, default=50, help="max sinkhorn iters")
+
+     parser.add_argument("--device", type=str, default="cuda:1", help="gpu or cpu")
+
+     parser.add_argument('--num_emb_list', type=int, nargs='+', default=[256, 256, 256], help='emb num of every vq')
+     parser.add_argument('--e_dim', type=int, default=32, help='vq codebook embedding size')
+     parser.add_argument('--quant_loss_weight', type=float, default=1.0, help='vq quantization loss weight')
+     parser.add_argument('--layers', type=int, nargs='+', default=[2048, 1024, 512, 256, 128, 64], help='hidden sizes of every layer')
+
+     parser.add_argument("--ckpt_dir", type=str, default="", help="output directory for model")
+
+     return parser.parse_args()
+
+
+ if __name__ == '__main__':
+     """fix the random seed"""
+     seed = 2023
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+     args = parse_args()
+     print(args)
+
+     logging.basicConfig(level=logging.DEBUG)
+
+     """build dataset"""
+     data = EmbDataset(args.data_path)
+     model = RQVAE(in_dim=data.dim,
+                   num_emb_list=args.num_emb_list,
+                   e_dim=args.e_dim,
+                   layers=args.layers,
+                   dropout_prob=args.dropout_prob,
+                   bn=args.bn,
+                   loss_type=args.loss_type,
+                   quant_loss_weight=args.quant_loss_weight,
+                   kmeans_init=args.kmeans_init,
+                   kmeans_iters=args.kmeans_iters,
+                   sk_epsilons=args.sk_epsilons,
+                   sk_iters=args.sk_iters,
+                   )
+     print(model)
+     data_loader = DataLoader(data, num_workers=args.num_workers,
+                              batch_size=args.batch_size, shuffle=True,
+                              pin_memory=True)
+     trainer = Trainer(args, model)
+     best_loss, best_collision_rate = trainer.fit(data_loader)
+
+     print("Best Loss", best_loss)
+     print("Best Collision Rate", best_collision_rate)
index/models/layers.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.init import xavier_normal_
+ from sklearn.cluster import KMeans
+
+
+ class MLPLayers(nn.Module):
+
+     def __init__(self, layers, dropout=0.0, activation="relu", bn=False):
+         super(MLPLayers, self).__init__()
+         self.layers = layers
+         self.dropout = dropout
+         self.activation = activation
+         self.use_bn = bn
+
+         mlp_modules = []
+         for idx, (input_size, output_size) in enumerate(
+             zip(self.layers[:-1], self.layers[1:])
+         ):
+             mlp_modules.append(nn.Dropout(p=self.dropout))
+             mlp_modules.append(nn.Linear(input_size, output_size))
+             if self.use_bn:
+                 mlp_modules.append(nn.BatchNorm1d(num_features=output_size))
+             activation_func = activation_layer(self.activation, output_size)
+             if activation_func is not None and idx != (len(self.layers) - 2):
+                 mlp_modules.append(activation_func)
+
+         self.mlp_layers = nn.Sequential(*mlp_modules)
+         self.apply(self.init_weights)
+
+     def init_weights(self, module):
+         # initialize linear weights with Xavier normal and biases with zero
+         if isinstance(module, nn.Linear):
+             xavier_normal_(module.weight.data)
+             if module.bias is not None:
+                 module.bias.data.fill_(0.0)
+
+     def forward(self, input_feature):
+         return self.mlp_layers(input_feature)
+
+ def activation_layer(activation_name="relu", emb_dim=None):
+
+     if activation_name is None:
+         activation = None
+     elif isinstance(activation_name, str):
+         if activation_name.lower() == "sigmoid":
+             activation = nn.Sigmoid()
+         elif activation_name.lower() == "tanh":
+             activation = nn.Tanh()
+         elif activation_name.lower() == "relu":
+             activation = nn.ReLU()
+         elif activation_name.lower() == "leakyrelu":
+             activation = nn.LeakyReLU()
+         elif activation_name.lower() == "none":
+             activation = None
+     elif issubclass(activation_name, nn.Module):
+         activation = activation_name()
+     else:
+         raise NotImplementedError(
+             "activation function {} is not implemented".format(activation_name)
+         )
+
+     return activation
+
+ def kmeans(
+     samples,
+     num_clusters,
+     num_iters=10,
+ ):
+     B, dim, dtype, device = samples.shape[0], samples.shape[-1], samples.dtype, samples.device
+     x = samples.cpu().detach().numpy()
+
+     cluster = KMeans(n_clusters=num_clusters, max_iter=num_iters).fit(x)
+
+     centers = cluster.cluster_centers_
+     tensor_centers = torch.from_numpy(centers).to(device)
+
+     return tensor_centers
+
+
+ @torch.no_grad()
+ def sinkhorn_algorithm(distances, epsilon, sinkhorn_iterations):
+     Q = torch.exp(-distances / epsilon)
+
+     B = Q.shape[0]  # number of samples to assign
+     K = Q.shape[1]  # how many centroids per block (usually set to 256)
+
+     # make the matrix sum to 1
+     sum_Q = Q.sum(-1, keepdim=True).sum(-2, keepdim=True)
+     Q /= sum_Q
+     # print(Q.sum())
+     for it in range(sinkhorn_iterations):
+
+         # normalize each row: total weight per sample must be 1/B
+         Q /= torch.sum(Q, dim=1, keepdim=True)
+         Q /= B
+
+         # normalize each column: total weight per centroid must be 1/K
+         Q /= torch.sum(Q, dim=0, keepdim=True)
+         Q /= K
+
+     Q *= B  # the rows must sum to 1 so that Q is an assignment
+     return Q
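The Sinkhorn routine above turns `exp(-distances / epsilon)` into a near-doubly-stochastic soft assignment by alternately normalizing rows (samples) and columns (centroids). A condensed, runnable sketch of the same normalization with toy shapes and hypothetical `epsilon`/iteration values:

```python
import torch

@torch.no_grad()
def sinkhorn(distances, epsilon=0.05, iters=5):
    # alternate row/column normalization of exp(-d/eps), as in sinkhorn_algorithm
    Q = torch.exp(-distances / epsilon)
    Q /= Q.sum()
    B, K = Q.shape  # B samples, K centroids
    for _ in range(iters):
        Q /= Q.sum(dim=1, keepdim=True)  # per-sample mass -> 1/B
        Q /= B
        Q /= Q.sum(dim=0, keepdim=True)  # per-centroid mass -> 1/K
        Q /= K
    return Q * B  # rescale so each row is (approximately) a distribution

Q = sinkhorn(torch.rand(8, 4))
print(Q.sum(dim=0))  # each column sums to exactly B/K = 2.0 after the last step
```

Because the loop ends on a column normalization, every centroid receives the same total mass (B/K), which is what prevents codebook collapse onto a few popular codes.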
index/models/rq.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ import torch.nn as nn
+
+ from .vq import VectorQuantizer
+
+
+ class ResidualVectorQuantizer(nn.Module):
+     """References:
+     SoundStream: An End-to-End Neural Audio Codec
+     https://arxiv.org/pdf/2107.03312.pdf
+     """
+
+     def __init__(self, n_e_list, e_dim, sk_epsilons,
+                  kmeans_init=False, kmeans_iters=100, sk_iters=100):
+         super().__init__()
+         self.n_e_list = n_e_list
+         self.e_dim = e_dim
+         self.num_quantizers = len(n_e_list)
+         self.kmeans_init = kmeans_init
+         self.kmeans_iters = kmeans_iters
+         self.sk_epsilons = sk_epsilons
+         self.sk_iters = sk_iters
+         self.vq_layers = nn.ModuleList([VectorQuantizer(n_e, e_dim,
+                                                         kmeans_init=self.kmeans_init,
+                                                         kmeans_iters=self.kmeans_iters,
+                                                         sk_epsilon=sk_epsilon,
+                                                         sk_iters=sk_iters)
+                                         for n_e, sk_epsilon in zip(n_e_list, sk_epsilons)])
+
+     def get_codebook(self):
+         all_codebook = []
+         for quantizer in self.vq_layers:
+             codebook = quantizer.get_codebook()
+             all_codebook.append(codebook)
+         return torch.stack(all_codebook)
+
+     def forward(self, x, use_sk=True):
+         all_losses = []
+         all_indices = []
+
+         x_q = 0
+         residual = x
+         for quantizer in self.vq_layers:
+             x_res, loss, indices = quantizer(residual, use_sk=use_sk)
+             residual = residual - x_res
+             x_q = x_q + x_res
+
+             all_losses.append(loss)
+             all_indices.append(indices)
+
+         mean_losses = torch.stack(all_losses).mean()
+         all_indices = torch.stack(all_indices, dim=-1)
+
+         return x_q, mean_losses, all_indices
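`ResidualVectorQuantizer.forward` implements the core idea behind the multi-level item indices: each quantizer level encodes the residual left over by the previous levels, so the code sequence is coarse-to-fine. A toy 1-D illustration of that loop, with hypothetical two-level codebooks:

```python
import numpy as np

# hypothetical two-level residual quantization: a coarse codebook plus a refinement codebook
codebooks = [np.array([0.0, 4.0, 8.0]), np.array([-1.0, 0.0, 1.0])]

def rq_encode(x):
    residual, ids, x_q = x, [], 0.0
    for cb in codebooks:
        i = int(np.argmin(np.abs(cb - residual)))  # nearest code at this level
        ids.append(i)
        residual = residual - cb[i]   # pass the leftover to the next level
        x_q = x_q + cb[i]             # accumulate the reconstruction
    return ids, x_q

ids, x_q = rq_encode(5.2)
print(ids, x_q)  # [1, 2] 5.0 -> coarse code 4.0 plus refinement 1.0
```

Each additional level shrinks the reconstruction error, and the per-level code indices are exactly what the generation scripts render into tokens like `<a_1>`.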
index/models/rqvae.py ADDED
@@ -0,0 +1,84 @@
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F

from .layers import MLPLayers
from .rq import ResidualVectorQuantizer


class RQVAE(nn.Module):
    def __init__(self,
                 in_dim=768,
                 num_emb_list=None,       # e.g. [256, 256, 256, 256]
                 e_dim=64,
                 layers=None,             # e.g. [512, 256, 128]
                 dropout_prob=0.0,
                 bn=False,
                 loss_type="mse",
                 quant_loss_weight=1.0,
                 kmeans_init=False,
                 kmeans_iters=100,
                 sk_epsilons=None,        # e.g. [0, 0, 0.003, 0.01]
                 sk_iters=100,
                 ):
        super(RQVAE, self).__init__()

        self.in_dim = in_dim
        self.num_emb_list = num_emb_list
        self.e_dim = e_dim

        self.layers = layers
        self.dropout_prob = dropout_prob
        self.bn = bn
        self.loss_type = loss_type
        self.quant_loss_weight = quant_loss_weight
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilons = sk_epsilons
        self.sk_iters = sk_iters

        self.encode_layer_dims = [self.in_dim] + self.layers + [self.e_dim]
        self.encoder = MLPLayers(layers=self.encode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

        self.rq = ResidualVectorQuantizer(num_emb_list, e_dim,
                                          kmeans_init=self.kmeans_init,
                                          kmeans_iters=self.kmeans_iters,
                                          sk_epsilons=self.sk_epsilons,
                                          sk_iters=self.sk_iters)

        self.decode_layer_dims = self.encode_layer_dims[::-1]
        self.decoder = MLPLayers(layers=self.decode_layer_dims,
                                 dropout=self.dropout_prob, bn=self.bn)

    def forward(self, x, use_sk=True):
        x = self.encoder(x)
        x_q, rq_loss, indices = self.rq(x, use_sk=use_sk)
        out = self.decoder(x_q)

        return out, rq_loss, indices

    @torch.no_grad()
    def get_indices(self, xs, use_sk=False):
        x_e = self.encoder(xs)
        _, _, indices = self.rq(x_e, use_sk=use_sk)
        return indices

    def compute_loss(self, out, quant_loss, xs=None):
        if self.loss_type == 'mse':
            loss_recon = F.mse_loss(out, xs, reduction='mean')
        elif self.loss_type == 'l1':
            loss_recon = F.l1_loss(out, xs, reduction='mean')
        else:
            raise ValueError('incompatible loss type')

        loss_total = loss_recon + self.quant_loss_weight * quant_loss

        return loss_total, loss_recon
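The total objective in `compute_loss` is the reconstruction error plus a weighted quantization loss. The arithmetic can be sketched without torch (a plain-Python `mse` standing in for `F.mse_loss`):

```python
def mse(a, b):
    """Mean squared error over two equal-length vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

def compute_loss(out, target, quant_loss, quant_loss_weight=1.0):
    """Total loss = reconstruction MSE + weighted quantization loss."""
    loss_recon = mse(out, target)
    loss_total = loss_recon + quant_loss_weight * quant_loss
    return loss_total, loss_recon

total, recon = compute_loss([1.0, 2.0], [1.0, 4.0], quant_loss=0.5)
# recon == 2.0 (mean of 0 and 4), total == 2.0 + 1.0 * 0.5 == 2.5
```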
index/models/vq.py ADDED
@@ -0,0 +1,104 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import kmeans, sinkhorn_algorithm


class VectorQuantizer(nn.Module):

    def __init__(self, n_e, e_dim,
                 beta=0.25, kmeans_init=False, kmeans_iters=10,
                 sk_epsilon=0.01, sk_iters=100):
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilon = sk_epsilon
        self.sk_iters = sk_iters

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        if not kmeans_init:
            self.initted = True
            self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        else:
            self.initted = False
            self.embedding.weight.data.zero_()

    def get_codebook(self):
        return self.embedding.weight

    def get_codebook_entry(self, indices, shape=None):
        # get quantized latent vectors
        z_q = self.embedding(indices)
        if shape is not None:
            z_q = z_q.view(shape)

        return z_q

    def init_emb(self, data):
        centers = kmeans(
            data,
            self.n_e,
            self.kmeans_iters,
        )

        self.embedding.weight.data.copy_(centers)
        self.initted = True

    @staticmethod
    def center_distance_for_constraint(distances):
        # distances: B, K
        max_distance = distances.max()
        min_distance = distances.min()

        middle = (max_distance + min_distance) / 2
        amplitude = max_distance - middle + 1e-5
        assert amplitude > 0
        centered_distances = (distances - middle) / amplitude
        return centered_distances

    def forward(self, x, use_sk=True):
        # Flatten input
        latent = x.view(-1, self.e_dim)

        if not self.initted and self.training:
            self.init_emb(latent)

        # Squared L2 distance between latents and codebook entries
        d = torch.sum(latent**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t() - \
            2 * torch.matmul(latent, self.embedding.weight.t())

        if not use_sk or self.sk_epsilon <= 0:
            indices = torch.argmin(d, dim=-1)
        else:
            d = self.center_distance_for_constraint(d)
            d = d.double()
            Q = sinkhorn_algorithm(d, self.sk_epsilon, self.sk_iters)
            if torch.isnan(Q).any() or torch.isinf(Q).any():
                print("Sinkhorn algorithm returned nan/inf values.")
            Q = torch.nan_to_num(Q, Q[torch.isfinite(Q)].min().item())
            indices = torch.argmax(Q, dim=-1)

        x_q = self.embedding(indices).view(x.shape)

        # compute loss for embedding
        commitment_loss = F.mse_loss(x_q.detach(), x)
        codebook_loss = F.mse_loss(x_q, x.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # preserve gradients (straight-through estimator)
        x_q = x + (x_q - x).detach()

        indices = indices.view(x.shape[:-1])

        return x_q, loss, indices
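When `use_sk` is active, the forward pass above replaces plain argmin assignment with a balanced assignment from `sinkhorn_algorithm` (whose implementation is not shown in this diff). A toy pure-Python sketch of Sinkhorn-style balancing, assuming scores `exp(-d/eps)` and simple alternating normalization; both samples prefer code 0, but balancing spreads them across codes:

```python
import math

def sinkhorn(dist, eps=0.1, iters=50):
    """Toy Sinkhorn: turn a BxK distance matrix into a roughly doubly
    balanced assignment matrix by alternating column/row normalization."""
    Q = [[math.exp(-d / eps) for d in row] for row in dist]
    B, K = len(Q), len(Q[0])
    for _ in range(iters):
        # normalize columns so each code receives comparable total mass
        col = [sum(Q[b][k] for b in range(B)) for k in range(K)]
        Q = [[Q[b][k] / col[k] for k in range(K)] for b in range(B)]
        # normalize rows so each sample distributes unit mass
        row = [sum(Q[b]) for b in range(B)]
        Q = [[q / row[b] for q in Q[b]] for b in range(B)]
    return Q

# Both samples are closest to code 0; plain argmin would assign both
# there (a collision), while the balanced plan separates them.
Q = sinkhorn([[0.0, 0.2], [0.0, 0.1]], eps=0.1)
picks = [row.index(max(row)) for row in Q]
# picks == [0, 1]
```

This balancing is what lets the index assignment stay (nearly) collision-free.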
index/run.sh ADDED
@@ -0,0 +1,8 @@
python -u main.py \
    --num_emb_list 256 256 256 256 \
    --sk_epsilons 0.0 0.0 0.0 0.003 \
    --device cuda:0 \
    --data_path /data/Games/Games.emb-llama-td.npy \
    --batch_size 1024
index/trainer.py ADDED
@@ -0,0 +1,209 @@
import logging
import os
from time import time

import numpy as np
import torch
from torch import optim
from tqdm import tqdm

from utils import ensure_dir, set_color, get_local_time


class Trainer(object):

    def __init__(self, args, model):
        self.args = args
        self.model = model
        self.logger = logging.getLogger()

        self.lr = args.lr
        self.learner = args.learner
        self.weight_decay = args.weight_decay
        self.epochs = args.epochs
        self.eval_step = min(args.eval_step, self.epochs)
        self.device = torch.device(args.device)
        saved_model_dir = "{}".format(get_local_time())
        self.ckpt_dir = os.path.join(args.ckpt_dir, saved_model_dir)
        ensure_dir(self.ckpt_dir)

        self.best_loss = np.inf
        self.best_collision_rate = np.inf
        self.best_loss_ckpt = "best_loss_model.pth"
        self.best_collision_ckpt = "best_collision_model.pth"
        self.optimizer = self._build_optimizer()
        self.model = self.model.to(self.device)

    def _build_optimizer(self):
        params = self.model.parameters()
        learner = self.learner
        learning_rate = self.lr
        weight_decay = self.weight_decay

        if learner.lower() == "adam":
            optimizer = optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "sgd":
            optimizer = optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
        elif learner.lower() == "adagrad":
            optimizer = optim.Adagrad(
                params, lr=learning_rate, weight_decay=weight_decay
            )
            # move any optimizer state onto the training device
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.to(self.device)
        elif learner.lower() == "rmsprop":
            optimizer = optim.RMSprop(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        elif learner.lower() == "adamw":
            optimizer = optim.AdamW(
                params, lr=learning_rate, weight_decay=weight_decay
            )
        else:
            self.logger.warning(
                "Received unrecognized optimizer, set default Adam optimizer"
            )
            optimizer = optim.Adam(params, lr=learning_rate)
        return optimizer

    def _check_nan(self, loss):
        if torch.isnan(loss):
            raise ValueError("Training loss is nan")

    def _train_epoch(self, train_data, epoch_idx):
        self.model.train()

        total_loss = 0
        total_recon_loss = 0
        iter_data = tqdm(
            train_data,
            total=len(train_data),
            ncols=100,
            desc=set_color(f"Train {epoch_idx}", "pink"),
        )

        for batch_idx, data in enumerate(iter_data):
            data = data.to(self.device)
            self.optimizer.zero_grad()
            out, rq_loss, indices = self.model(data)
            loss, loss_recon = self.model.compute_loss(out, rq_loss, xs=data)
            self._check_nan(loss)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
            total_recon_loss += loss_recon.item()

        return total_loss, total_recon_loss

    @torch.no_grad()
    def _valid_epoch(self, valid_data):
        self.model.eval()

        iter_data = tqdm(
            valid_data,
            total=len(valid_data),
            ncols=100,
            desc=set_color("Evaluate ", "pink"),
        )
        indices_set = set()
        num_sample = 0
        for batch_idx, data in enumerate(iter_data):
            num_sample += len(data)
            data = data.to(self.device)
            indices = self.model.get_indices(data)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for index in indices:
                code = "-".join([str(int(_)) for _ in index])
                indices_set.add(code)

        # fraction of samples whose index tuple collides with another sample's
        collision_rate = (num_sample - len(indices_set)) / num_sample

        return collision_rate

    def _save_checkpoint(self, epoch, collision_rate=1, ckpt_file=None):
        ckpt_path = os.path.join(self.ckpt_dir, ckpt_file) if ckpt_file \
            else os.path.join(self.ckpt_dir, 'epoch_%d_collision_%.4f_model.pth' % (epoch, collision_rate))
        state = {
            "args": self.args,
            "epoch": epoch,
            "best_loss": self.best_loss,
            "best_collision_rate": self.best_collision_rate,
            "state_dict": self.model.state_dict(),
            "optimizer": self.optimizer.state_dict(),
        }
        torch.save(state, ckpt_path, pickle_protocol=4)

        self.logger.info(
            set_color("Saving current", "blue") + f": {ckpt_path}"
        )

    def _generate_train_loss_output(self, epoch_idx, s_time, e_time, loss, recon_loss):
        train_loss_output = (
            set_color("epoch %d training", "green")
            + " ["
            + set_color("time", "blue")
            + ": %.2fs, "
        ) % (epoch_idx, e_time - s_time)
        train_loss_output += set_color("train loss", "blue") + ": %.4f" % loss
        train_loss_output += ", "
        train_loss_output += set_color("reconstruction loss", "blue") + ": %.4f" % recon_loss
        return train_loss_output + "]"

    def fit(self, data):
        cur_eval_step = 0

        for epoch_idx in range(self.epochs):
            # train
            training_start_time = time()
            train_loss, train_recon_loss = self._train_epoch(data, epoch_idx)
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss, train_recon_loss
            )
            self.logger.info(train_loss_output)

            if train_loss < self.best_loss:
                self.best_loss = train_loss

            # eval
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                collision_rate = self._valid_epoch(data)

                if collision_rate < self.best_collision_rate:
                    self.best_collision_rate = collision_rate
                    cur_eval_step = 0
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate,
                                          ckpt_file=self.best_collision_ckpt)
                else:
                    cur_eval_step += 1

                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("collision_rate", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, collision_rate)

                self.logger.info(valid_score_output)
                if epoch_idx > 1000:
                    self._save_checkpoint(epoch_idx, collision_rate=collision_rate)

        return self.best_loss, self.best_collision_rate
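`_valid_epoch` above scores an epoch by collision rate: the fraction of samples whose joined index string duplicates another sample's. The same computation in isolation:

```python
def collision_rate(index_tuples):
    """Fraction of samples whose code tuple is not unique."""
    codes = {"-".join(str(int(i)) for i in idx) for idx in index_tuples}
    return (len(index_tuples) - len(codes)) / len(index_tuples)

rate = collision_rate([(3, 7, 1), (3, 7, 1), (2, 0, 5), (9, 9, 9)])
# two of four samples share code "3-7-1", so rate == 0.25
```

A rate of 0 means every item received a distinct index, which is the property the trainer selects checkpoints for.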
index/utils.py ADDED
@@ -0,0 +1,36 @@
import datetime
import os


def ensure_dir(dir_path):
    os.makedirs(dir_path, exist_ok=True)


def set_color(log, color, highlight=True):
    color_set = ["black", "red", "green", "yellow", "blue", "pink", "cyan", "white"]
    try:
        index = color_set.index(color)
    except ValueError:
        index = len(color_set) - 1
    prev_log = "\033["
    if highlight:
        prev_log += "1;3"
    else:
        prev_log += "0;3"
    prev_log += str(index) + "m"
    return prev_log + log + "\033[0m"


def get_local_time():
    r"""Get current time.

    Returns:
        str: current time, e.g. "Apr-01-2024_01-25-11"
    """
    cur = datetime.datetime.now()
    cur = cur.strftime("%b-%d-%Y_%H-%M-%S")

    return cur
infer.sh ADDED
@@ -0,0 +1,14 @@
CKPT_PATH=${datain}/v-yinju/rq-llama/v6/Instruments

python generate_embeddings.py \
    --ckpt_path $CKPT_PATH \
    --item_save_path $CKPT_PATH/embeddings.item.tsv \
    --user_save_path $CKPT_PATH/embeddings.user.tsv \
    --device_map 0

python generate_indices.py \
    --ckpt_path $CKPT_PATH \
    --item_data_path $CKPT_PATH/embeddings.item.tsv \
    --user_data_path $CKPT_PATH/embeddings.user.tsv \
    --save_path $CKPT_PATH \
    --device_map 0
instruments_evaluate.sh ADDED
@@ -0,0 +1,18 @@
DATASET=Instruments
BASE_MODEL=/datain/v-yinju/llama-7b
DATA_PATH=/datain/v-yinju/rqvae-zzx/data
CKPT_PATH=/datain/v-yinju/rq-llama/v11.2/Ins/finetune
RESULTS_FILE=$CKPT_PATH/eval_result.json

torchrun --nproc_per_node=8 evaluate-finetuned.py \
    --base_model $BASE_MODEL \
    --ckpt_path $CKPT_PATH \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --results_file $RESULTS_FILE \
    --test_batch_size 1 \
    --num_beams 20 \
    --test_prompt_ids all \
    --test_task seqrec \
    --index_file /datain/v-yinju/rq-llama/v11.2/Ins/indices.item.json \
    --user_index_file /datain/v-yinju/rq-llama/v11.2/Ins/indices.user.json
instruments_finetune.sh ADDED
@@ -0,0 +1,33 @@
export WANDB_MODE=disabled
export CUDA_LAUNCH_BLOCKING=0

DATASET=Instruments
CKPT_PATH=/datain/v-yinju/rq-llama/v11/Instruments
DATA_PATH=/datain/v-yinju/rqvae-zzx/data
OUTPUT_DIR=$CKPT_PATH/finetune

torchrun --nproc_per_node=8 fine-tune.py \
    --ckpt_path $CKPT_PATH \
    --output_dir $OUTPUT_DIR \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --per_device_batch_size 6 \
    --gradient_accumulation_steps 2 \
    --learning_rate 5e-5 \
    --epochs 4 \
    --weight_decay 0.01 \
    --save_and_eval_strategy epoch \
    --fp16 \
    --deepspeed ./config/ds_z2_fp16.json \
    --dataloader_num_workers 4 \
    --only_train_response \
    --tasks seqrec,itemsearch,preferenceobtain,item2index,index2item,fusionseqrec,usersearch,user2pref,pref2user \
    --train_prompt_sample_num 1,1,1,1,1,1,1,1,1 \
    --train_data_sample_num 0,0,0,0,0,0,0,0,0 \
    --index_file $CKPT_PATH/indices.item.json \
    --user_index_file $CKPT_PATH/indices.user.json \
    --reindex 17

cd convert
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
cd ..
instruments_more_pretrain.sh ADDED
@@ -0,0 +1,32 @@
export WANDB_MODE=disabled
export CUDA_LAUNCH_BLOCKING=0

DATASET=Instruments
BASE_MODEL=/datain/v-yinju/llama-7b
CKPT_PATH=/datain/v-yinju/rq-llama/v6/Instruments
DATA_PATH=/datain/v-yinju/rqvae-zzx/data
OUTPUT_DIR=/datain/v-yinju/rq-llama/v3-train/Instruments/more_pretrain

torchrun --nproc_per_node=8 --master_port=3324 continue_pretrain.py \
    --base_model $BASE_MODEL \
    --ckpt_path $CKPT_PATH \
    --output_dir $OUTPUT_DIR \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --per_device_batch_size 6 \
    --gradient_accumulation_steps 2 \
    --learning_rate 5e-5 \
    --epochs 4 \
    --weight_decay 0.01 \
    --save_and_eval_strategy epoch \
    --deepspeed ./config/ds_z2_fp16.json \
    --dataloader_num_workers 4 \
    --only_train_response \
    --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item \
    --train_prompt_sample_num 1,1,1,1,1,1,1,1,1 \
    --train_data_sample_num 0,0,0,0,0,0,0,0,0 \
    --fp16 &>>$OUTPUT_DIR/pretrain-log.txt

cd convert
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
cd ..
instruments_pretrain.sh ADDED
@@ -0,0 +1,85 @@
export WANDB_MODE=disabled
export CUDA_LAUNCH_BLOCKING=0

DATASET=Instruments
BASE_MODEL=$datain/v-yinju/llama-7b
ITEM_MODEL=$datain/v-yinju/rqvae-zzx/models/instruments/Apr-01-2024_01-25-11/best_collision_model.pth
USER_MODEL=$datain/v-yinju/rqvae-zzx/models/instruments/user/Apr-23-2024_03-36-04/best_collision_model.pth
DATA_PATH=$datain/v-yinju/rqvae-zzx/data
OUTPUT_DIR=$datain/v-yinju/rq-llama/v11.2/Ins

torchrun --nproc_per_node=8 pre-train.py \
    --base_model $BASE_MODEL \
    --item_model $ITEM_MODEL \
    --user_model $USER_MODEL \
    --output_dir $OUTPUT_DIR \
    --dataset $DATASET \
    --data_path $DATA_PATH \
    --per_device_batch_size 6 \
    --gradient_accumulation_steps 2 \
    --learning_rate 5e-4 \
    --epochs 4 \
    --weight_decay 0.01 \
    --save_and_eval_strategy epoch \
    --deepspeed ./config/ds_z2_fp16.json \
    --dataloader_num_workers 4 \
    --only_train_response \
    --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item,usersearch,user2pref,pref2user \
    --train_prompt_sample_num 1,1,1,1,1,1,1,1,1,1,1,1 \
    --train_data_sample_num 0,0,0,0,0,0,0,0,0,0,0,0 \
    --index_file .index.json \
    --user_index_file .user-index.json \
    --fp16

cd convert
nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
cd ..

CKPT_PATH=$datain/v-yinju/rq-llama/v11.2/Ins

python generate_embeddings.py \
    --ckpt_path $CKPT_PATH \
    --item_save_path $CKPT_PATH/embeddings.item.tsv \
    --user_save_path $CKPT_PATH/embeddings.user.tsv \
    --device_map 0

python generate_indices.py \
    --ckpt_path $CKPT_PATH \
    --item_data_path $CKPT_PATH/embeddings.item.tsv \
    --user_data_path $CKPT_PATH/embeddings.user.tsv \
    --save_path $CKPT_PATH \
    --device_map 0

# Alternative configuration for the Games dataset:
# DATASET=Games
# BASE_MODEL=/datain/v-yinju/llama-7b
# ITEM_MODEL=/datain/v-yinju/rqvae-zzx/models/games/Apr-18-2024_01-51-46/best_collision_model.pth
# USER_MODEL=/datain/v-yinju/rqvae-zzx/models/games/user/Jun-17-2024_18-40-36/best_collision_model.pth
# DATA_PATH=/datain/v-yinju/rqvae-zzx/data
# OUTPUT_DIR=/datain/v-yinju/rq-llama/v11/Games
#
# torchrun --nproc_per_node=8 pre-train.py \
#     --base_model $BASE_MODEL \
#     --item_model $ITEM_MODEL \
#     --user_model $USER_MODEL \
#     --output_dir $OUTPUT_DIR \
#     --dataset $DATASET \
#     --data_path $DATA_PATH \
#     --per_device_batch_size 6 \
#     --gradient_accumulation_steps 2 \
#     --learning_rate 5e-5 \
#     --epochs 4 \
#     --weight_decay 0.01 \
#     --save_and_eval_strategy epoch \
#     --deepspeed ./config/ds_z2_fp16.json \
#     --dataloader_num_workers 4 \
#     --only_train_response \
#     --tasks seqrec,itemsearch,inters2title,inters2description,preferenceobtain,item2index,index2item,intertitles2item,query2item,usersearch,user2pref,pref2user \
#     --train_prompt_sample_num 1,1,1,1,1,1,1,1,1,1,1,1 \
#     --train_data_sample_num 0,0,0,0,0,0,0,0,0,0,0,0 \
#     --index_file .index.json \
#     --user_index_file .user-index.json \
#     --fp16
#
# cd convert
# nohup ./convert.sh $OUTPUT_DIR >convert.log 2>&1 &
# cd ..
lora_finetune.py ADDED
@@ -0,0 +1,162 @@
import argparse
import os
import sys

import torch
import transformers

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from utils import *
from collator import Collator


def train(args):
    set_seed(args.seed)
    ensure_dir(args.output_dir)

    device_map = "auto"
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    local_rank = int(os.environ.get("LOCAL_RANK") or 0)
    if local_rank == 0:
        print(vars(args))

    if ddp:
        device_map = {"": local_rank}

    config = LlamaConfig.from_pretrained(args.base_model)
    tokenizer = LlamaTokenizer.from_pretrained(
        args.base_model,
        model_max_length=args.model_max_length,
        padding_side="right",
    )
    tokenizer.pad_token_id = 0

    train_data, valid_data = load_datasets(args)
    add_num = tokenizer.add_tokens(train_data.datasets[0].get_new_tokens())
    config.vocab_size = len(tokenizer)
    if local_rank == 0:
        print("add {} new tokens.".format(add_num))
        print("data num:", len(train_data))
        tokenizer.save_pretrained(args.output_dir)
        config.save_pretrained(args.output_dir)

    collator = Collator(args, tokenizer)

    model = LlamaForCausalLM.from_pretrained(
        args.base_model,
        # torch_dtype=torch.float16,
        device_map=device_map,
    )
    model.resize_token_embeddings(len(tokenizer))

    config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules.split(","),
        modules_to_save=args.lora_modules_to_save.split(","),
        lora_dropout=args.lora_dropout,
        bias="none",
        inference_mode=False,
        task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, config)

    if args.resume_from_checkpoint:
        # only the LoRA weights - the LoRA config above has to match
        checkpoint_name = os.path.join(
            args.resume_from_checkpoint, "adapter_model.bin"
        )
        args.resume_from_checkpoint = False  # so the trainer won't try loading its state
        if os.path.exists(checkpoint_name):
            if local_rank == 0:
                print(f"Restarting from {checkpoint_name}")
            adapters_weights = torch.load(checkpoint_name)
            # set_peft_model_state_dict loads the weights in place
            set_peft_model_state_dict(model, adapters_weights)
        else:
            if local_rank == 0:
                print(f"Checkpoint {checkpoint_name} not found")

    # freeze the frozen copies kept alongside modules_to_save
    for n, p in model.named_parameters():
        if "original_module" in n and any(module_name in n for module_name in config.modules_to_save):
            p.requires_grad = False

    if local_rank == 0:
        model.print_trainable_parameters()

    if not ddp and torch.cuda.device_count() > 1:
        model.is_parallelizable = True
        model.model_parallel = True

    trainer = transformers.Trainer(
        model=model,
        train_dataset=train_data,
        eval_dataset=valid_data,
        args=transformers.TrainingArguments(
            seed=args.seed,
            per_device_train_batch_size=args.per_device_batch_size,
            per_device_eval_batch_size=args.per_device_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_ratio=args.warmup_ratio,
            num_train_epochs=args.epochs,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
            lr_scheduler_type=args.lr_scheduler_type,
            fp16=args.fp16,
            bf16=args.bf16,
            logging_steps=args.logging_step,
            optim=args.optim,
            gradient_checkpointing=True,
            evaluation_strategy=args.save_and_eval_strategy,
            save_strategy=args.save_and_eval_strategy,
            eval_steps=args.save_and_eval_steps,
            save_steps=args.save_and_eval_steps,
            output_dir=args.output_dir,
            save_total_limit=5,
            load_best_model_at_end=True,
            deepspeed=args.deepspeed,
            ddp_find_unused_parameters=False if ddp else None,
            report_to=None,
            eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
        ),
        tokenizer=tokenizer,
        data_collator=collator,
    )
    model.config.use_cache = False

    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    trainer.train(
        resume_from_checkpoint=args.resume_from_checkpoint,
    )

    trainer.save_state()
    trainer.save_model(output_dir=args.output_dir)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='LLMRec')
    parser = parse_global_args(parser)
    parser = parse_train_args(parser)
    parser = parse_dataset_args(parser)

    args = parser.parse_args()

    train(args)
pre-train.py ADDED
@@ -0,0 +1,154 @@
import os
import sys
import argparse

import wandb
import torch
import transformers
from transformers import LlamaForCausalLM, LlamaTokenizer, LlamaConfig

from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)

from collator import VanillaCollator
from rq_llama import *
from utils import *

parser = argparse.ArgumentParser(description='rqllama-pretrain')
parser = parse_global_args(parser)
parser = parse_train_args(parser)
parser = parse_dataset_args(parser)
parser = parse_rqvae_args(parser)
args = parser.parse_args()
wandb.init(config=args, reinit=True)

set_seed(args.seed)
ensure_dir(args.output_dir)

device_map = "auto"
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
local_rank = int(os.environ.get("LOCAL_RANK") or 0)
if local_rank == 0:
    print(vars(args))
if ddp:
    device_map = {"": local_rank}

train_data, valid_data = load_datasets(args)

config = LlamaConfig.from_pretrained(args.base_model)
config.args = vars(args)
rqllama = LlamaWithRQ(config)

# load the pretrained item/user RQ-VAEs and mark their codebooks as initialized
ckpt = torch.load(args.item_model, map_location=torch.device('cpu'))
rqllama.item_rqvae.load_state_dict(ckpt["state_dict"])
for i in range(len(args.num_emb_list)):
    rqllama.item_rqvae.rq.vq_layers[i].initted = True
ckpt = torch.load(args.user_model, map_location=torch.device('cpu'))
rqllama.user_rqvae.load_state_dict(ckpt["state_dict"])
for i in range(len(args.num_emb_list)):
    rqllama.user_rqvae.rq.vq_layers[i].initted = True

if local_rank == 0:
    print("token num:", len(rqllama.tokenizer))
    print("data num:", len(train_data))
    rqllama.tokenizer.save_pretrained(args.output_dir)
    rqllama.config.save_pretrained(args.output_dir)

if args.resume_from_checkpoint:
    checkpoint_name = os.path.join(args.resume_from_checkpoint, "adapter_model.bin")
    args.resume_from_checkpoint = False
    if os.path.exists(checkpoint_name):
        if local_rank == 0:
            print(f"Restarting from {checkpoint_name}")
        adapters_weights = torch.load(checkpoint_name)
        # set_peft_model_state_dict loads the weights in place
        set_peft_model_state_dict(rqllama.model, adapters_weights)
    else:
        if local_rank == 0:
            print(f"Checkpoint {checkpoint_name} not found")

if local_rank == 0:
    rqllama.model.print_trainable_parameters()

if not ddp and torch.cuda.device_count() > 1:
    rqllama.is_parallelizable = True
    rqllama.model_parallel = True

collator = VanillaCollator(args, rqllama.tokenizer)

trainer = transformers.Trainer(
    model=rqllama,
    train_dataset=train_data,
    eval_dataset=valid_data,
    args=transformers.TrainingArguments(
        seed=args.seed,
        per_device_train_batch_size=args.per_device_batch_size,
        per_device_eval_batch_size=args.per_device_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_ratio=args.warmup_ratio,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        lr_scheduler_type=args.lr_scheduler_type,
        fp16=args.fp16,
        bf16=args.bf16,
        logging_steps=args.logging_step,
        optim=args.optim,
        gradient_checkpointing=True,
        evaluation_strategy=args.save_and_eval_strategy,
        save_strategy=args.save_and_eval_strategy,
        eval_steps=args.save_and_eval_steps,
        save_steps=args.save_and_eval_steps,
        output_dir=args.output_dir,
        save_total_limit=5,
        load_best_model_at_end=True,
        deepspeed=args.deepspeed,
        ddp_find_unused_parameters=False if ddp else None,
        report_to=None,
        eval_delay=1 if args.save_and_eval_strategy == "epoch" else 2000,
        dataloader_num_workers=args.dataloader_num_workers,
        dataloader_prefetch_factor=args.dataloader_prefetch_factor,
        remove_unused_columns=args.remove_unused_columns,
    ),
    tokenizer=rqllama.tokenizer,
    data_collator=collator,
)
rqllama.config.use_cache = False

if torch.__version__ >= "2" and sys.platform != "win32":
    rqllama = torch.compile(rqllama)

trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)

trainer.save_state()
trainer.save_model(output_dir=args.output_dir)

if local_rank == 0:
    # notify by email that the run has finished
    import smtplib
    from email.mime.text import MIMEText

    mail_host = 'smtp.qq.com'
    mail_code = 'ouzplpngooqndjcb'
    sender = '1849334588@qq.com'
    receiver = 'esperanto1949@foxmail.com'

    task = '[v53: pretrain tt.ins.5e-4 w/o projector]'
    message = MIMEText('Task {task} Finished'.format(task=task), 'plain', 'utf-8')
    message['Subject'] = 'Auto Email'
    message['From'] = sender
    message['To'] = receiver

    server = smtplib.SMTP_SSL(mail_host, 465)
    server.login(sender, mail_code)
    server.sendmail(sender, receiver, message.as_string())

    server.quit()