|
|
import copy
|
|
|
import random
|
|
|
import argparse
|
|
|
import os
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
from torch.utils.data import Dataset
|
|
|
from tqdm import tqdm
|
|
|
from collections import defaultdict
|
|
|
import torch.distributed as dist
|
|
|
import logging
|
|
|
import re
|
|
|
import pdb
|
|
|
import json
|
|
|
from prompt_finetune import sft_prompt, all_prompt
|
|
|
import numpy as np
|
|
|
|
|
|
|
|
|
class BaseDataset(Dataset):
    """Base class for the finetuning datasets.

    Holds the configuration shared by every task-specific dataset (paths,
    history truncation, separators) and common utilities: loading the
    item -> index-token mapping, collecting the vocabulary of index tokens,
    and building the constrained-decoding callback used at generation time.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args
        self.dataset = args.dataset
        self.data_path = os.path.join(args.data_path, self.dataset)

        self.max_his_len = args.max_his_len  # max history entries kept (<= 0 keeps all)
        self.his_sep = args.his_sep          # separator between rendered history entries
        self.index_file = args.index_file
        self.add_prefix = args.add_prefix    # number history entries as "1. ", "2. ", ...

        # lazily-computed caches
        self.new_tokens = None
        self.allowed_tokens = None
        self.all_items = None

    def _load_data(self):
        """Load the item -> index-token-sequence mapping from JSON."""
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def get_new_tokens(self):
        """Return the sorted list of all distinct index tokens (cached)."""
        if self.new_tokens is not None:
            return self.new_tokens

        # flatten every index sequence into one set, then sort for determinism
        token_set = set()
        for index in self.indices.values():
            token_set.update(index)
        self.new_tokens = sorted(token_set)

        return self.new_tokens

    def get_all_items(self):
        """Return the set of all item identifier strings (cached)."""
        if self.all_items is not None:
            return self.all_items

        self.all_items = {"".join(index) for index in self.indices.values()}

        return self.all_items

    def get_prefix_allowed_tokens_fn(self, tokenizer):
        """Build a ``prefix_allowed_tokens_fn`` for constrained generation.

        For each position ``i`` within an item index, collect the token ids
        that may legally appear there; one step past the longest index only
        EOS is allowed.  The returned callback locates the "Response:" marker
        in the generated sequence and restricts the next token accordingly.
        """
        if self.allowed_tokens is None:
            self.allowed_tokens = {}
            for index in self.indices.values():
                for i, token in enumerate(index):
                    # NOTE(review): [1] skips the tokenizer's leading special
                    # token -- assumes each index token encodes to a single id
                    # at position 1; confirm for the tokenizer in use.
                    token_id = tokenizer(token)["input_ids"][1]
                    self.allowed_tokens.setdefault(i, set()).add(token_id)
            # position after the last index token: force EOS
            self.allowed_tokens[len(self.allowed_tokens)] = {tokenizer.eos_token_id}
        sep = tokenizer("Response:")["input_ids"][1:]

        def prefix_allowed_tokens_fn(batch_id, sentence):
            sentence = sentence.tolist()
            reversed_sent = sentence[::-1]
            # i counts the tokens generated since the "Response:" marker;
            # returns None (unconstrained) if the marker is not found
            for i in range(len(reversed_sent)):
                if reversed_sent[i:i + len(sep)] == sep[::-1]:
                    return list(self.allowed_tokens[i])

        return prefix_allowed_tokens_fn

    def _process_data(self):
        """Subclasses build their task-specific examples here."""
        raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
class SeqRecFinetune(BaseDataset):
    """Sequential-recommendation finetuning dataset.

    Each example asks the model to predict the next item (rendered as its
    index-token string) from the user's interaction history.  Per user, the
    last two interactions are held out: position ``-2`` for validation and
    ``-1`` for test; training uses every prefix of the remaining sequence.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num  # prompts sampled per interaction
        self.prompt_id = prompt_id                  # fixed prompt used at test time
        self.sample_num = sample_num                # test subsample size (-1 = keep all)

        self.prompts = all_prompt["seqrec"]

        self._load_data()
        self._remap_items()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Load user interactions and the item index mapping."""
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        # NOTE: unlike BaseDataset._load_data, the index file path is used as given
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Replace raw item ids with their concatenated index-token strings."""
        self.remapped_inters = dict()
        for uid, items in self.inters.items():
            new_items = ["".join(self.indices[str(i)]) for i in items]
            self.remapped_inters[uid] = new_items

    def _build_example(self, history, target):
        """Build one {item, inters} dict from a history list and its target item."""
        one_data = dict()
        one_data["item"] = target
        if self.max_his_len > 0:
            history = history[-self.max_his_len:]
        if self.add_prefix:
            history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
        one_data["inters"] = self.his_sep.join(history)
        return one_data

    def _process_train_data(self):
        """Every prefix of each user's sequence, holding out the last two items."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid][:-2]
            for i in range(1, len(items)):
                inter_data.append(self._build_example(items[:i], items[i]))
        return inter_data

    def _process_valid_data(self):
        """One example per user targeting the second-to-last item."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            inter_data.append(self._build_example(items[:-2], items[-2]))
        return inter_data

    def _process_test_data(self):
        """One example per user targeting the last item, optionally subsampled."""
        inter_data = []
        for uid in self.remapped_inters:
            items = self.remapped_inters[uid]
            inter_data.append(self._build_example(items[:-1], items[-1]))

        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        """Select the fixed prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation split into (input, label) text pairs."""
        self.valid_text_data = []
        if self.sample_valid:
            # sample prompt_sample_num distinct prompts per interaction
            all_prompt_ids = range(len(self.prompts))
            for d in self.inter_data:
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    input, output = self._get_text_data(d, self.prompts[prompt_id])
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # single fixed validation prompt for every interaction
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for d in self.inter_data:
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        """Render ``data`` through ``prompt``; at test time the label is the bare response."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            # random prompt per access acts as augmentation
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # guard: avoids an UnboundLocalError on an unexpected mode
            raise NotImplementedError

        input, output = self._get_text_data(d, self.prompts[prompt_id])

        return dict(input_ids=input, labels=output)
|
|
|
|
|
|
|
|
|
class FusionSeqRecFinetune(BaseDataset):
    """Sequential recommendation with fused item features.

    Like :class:`SeqRecFinetune`, but each example also carries the target
    item's title and description plus the titles of the history items, so
    prompts can mix index tokens with natural-language metadata.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num  # prompts sampled per interaction
        self.prompt_id = prompt_id                  # fixed prompt used at test time
        self.sample_num = sample_num                # subsample size (-1 = keep all)

        self.prompts = all_prompt["fusionseqrec"]

        self._load_data()

        if self.mode == 'train':
            self.inter_data = self._process_train_data()
        elif self.mode == 'valid':
            self.sample_valid = args.sample_valid
            self.valid_prompt_id = args.valid_prompt_id
            self.inter_data = self._process_valid_data()
            self._construct_valid_text()
        elif self.mode == 'test':
            self.inter_data = self._process_test_data()
        else:
            raise NotImplementedError

    def _load_data(self):
        """Load interactions, the index mapping, and per-item features."""
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        # NOTE: unlike BaseDataset._load_data, the index file path is used as given
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _build_example(self, target_iid, history_ids):
        """Build one example dict from a raw target id and raw history ids."""
        one_data = dict()
        one_data["item"] = "".join(self.indices[str(target_iid)])
        one_data["title"] = self.item_feat[str(target_iid)]["title"].strip().strip(".!?,;:`")
        one_data["description"] = self.item_feat[str(target_iid)]["description"]

        if self.max_his_len > 0:
            history_ids = history_ids[-self.max_his_len:]
        inters = ["".join(self.indices[str(j)]) for j in history_ids]
        # titles are quoted so they read as literals inside the prompt
        inter_titles = ["\"" + self.item_feat[str(j)]["title"].strip().strip(".!?,;:`") + "\"" for j in history_ids]

        if self.add_prefix:
            inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]
            inter_titles = [str(k + 1) + ". " + item_title for k, item_title in enumerate(inter_titles)]

        one_data["inters"] = self.his_sep.join(inters)
        one_data["inter_titles"] = self.his_sep.join(inter_titles)
        return one_data

    def _subsample(self, inter_data):
        """Optionally subsample ``inter_data`` without replacement."""
        if self.sample_num > 0:
            all_inter_idx = range(len(inter_data))
            sample_idx = np.random.choice(all_inter_idx, self.sample_num, replace=False)
            inter_data = np.array(inter_data)[sample_idx].tolist()
        return inter_data

    def _process_train_data(self):
        """Every prefix of each user's sequence, holding out the last two items."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid][:-2]
            for i in range(1, len(items)):
                inter_data.append(self._build_example(items[i], items[:i]))
        return self._subsample(inter_data)

    def _process_valid_data(self):
        """One example per user targeting the second-to-last item."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._build_example(items[-2], items[:-2]))
        return self._subsample(inter_data)

    def _process_test_data(self):
        """One example per user targeting the last item."""
        inter_data = []
        for uid in self.inters:
            items = self.inters[uid]
            inter_data.append(self._build_example(items[-1], items[:-1]))
        return self._subsample(inter_data)

    def set_prompt(self, prompt_id):
        """Select the fixed prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.inter_data) * self.prompt_sample_num
        elif self.mode == 'valid':
            return len(self.valid_text_data)
        elif self.mode == 'test':
            return len(self.inter_data)
        else:
            raise NotImplementedError

    def _construct_valid_text(self):
        """Pre-render the validation split into (input, label) text pairs."""
        self.valid_text_data = []
        if self.sample_valid:
            # sample prompt_sample_num distinct prompts per interaction
            all_prompt_ids = range(len(self.prompts))
            for d in self.inter_data:
                prompt_ids = np.random.choice(all_prompt_ids, self.prompt_sample_num, replace=False)
                for prompt_id in prompt_ids:
                    input, output = self._get_text_data(d, self.prompts[prompt_id])
                    self.valid_text_data.append({"input_ids": input, "labels": output})
        else:
            # single fixed validation prompt for every interaction
            self.prompt_sample_num = 1
            prompt = self.prompts[self.valid_prompt_id]
            for d in self.inter_data:
                input, output = self._get_text_data(d, prompt)
                self.valid_text_data.append({"input_ids": input, "labels": output})

    def _get_text_data(self, data, prompt):
        """Render ``data`` through ``prompt``; at test time the label is the bare response."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        if self.mode == 'valid':
            return self.valid_text_data[index]

        idx = index // self.prompt_sample_num
        d = self.inter_data[idx]

        if self.mode == 'train':
            # random prompt per access acts as augmentation
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # guard: avoids an UnboundLocalError on an unexpected mode
            raise NotImplementedError

        prompt = self.prompts[prompt_id]

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
|
|
|
|
|
|
|
|
|
class ItemFeatFinetune(BaseDataset):
    """Item feature <-> index alignment dataset.

    ``task`` selects the prompt pool (e.g. "item2index" / "index2item");
    every item contributes one example pairing its metadata with its
    index-token string.
    """

    def __init__(self, args, task="item2index", prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.task = task.lower()
        self.prompt_sample_num = prompt_sample_num  # prompts sampled per item
        self.sample_num = sample_num                # subsample size (-1 = keep all)

        # prompt pool for the requested alignment direction
        self.prompts = all_prompt[self.task]

        self._load_data()
        self.feat_data = self._process_data()

    def _load_data(self):
        """Read the index mapping and the per-item feature file."""
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".item.json"), 'r') as f:
            self.item_feat = json.load(f)

    def _process_data(self):
        """Attach each item's index string to its feature dict.

        Note: the feature dicts in ``self.item_feat`` are modified in place
        ("item" key added, "title" cleaned).
        """
        feat_data = []
        for iid, feat in self.item_feat.items():
            feat["item"] = "".join(self.indices[iid])
            feat["title"] = feat["title"].strip().strip(".!?,;:`")
            feat_data.append(feat)

        if self.sample_num > 0:
            chosen = np.random.choice(range(len(feat_data)), self.sample_num, replace=False)
            feat_data = np.array(feat_data)[chosen].tolist()

        return feat_data

    def __len__(self):
        return len(self.feat_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        """Render one (input, label) pair from ``data`` via ``prompt``."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        d = self.feat_data[index // self.prompt_sample_num]
        # a random prompt per access acts as augmentation
        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]
        input, output = self._get_text_data(d, prompt)
        return dict(input_ids=input, labels=output)
|
|
|
|
|
|
|
|
|
class ItemSearchFinetune(BaseDataset):
    """Item-search finetuning dataset.

    Each user contributes explicit preferences and two vague-intention
    queries (user-related and item-related); one preference and one query
    are sampled per access, and the model must produce the target item's
    index-token string.
    """

    def __init__(self, args, mode="train",
                 prompt_sample_num=1, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.mode = mode
        self.prompt_sample_num = prompt_sample_num  # prompts sampled per user
        self.prompt_id = prompt_id                  # fixed prompt used at test time
        self.sample_num = sample_num                # subsample size (-1 = keep all)

        self.prompts = all_prompt["itemsearch"]

        self._load_data()
        self.search_data = self._process_data()

    def _load_data(self):
        """Read the index mapping and the per-user info file."""
        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)

    def _process_data(self):
        """Build one search example per user for the current split.

        Raises:
            NotImplementedError: for modes other than 'train' / 'test'.
        """
        search_data = []
        user_explicit_preference = self.user_info["user_explicit_preference"]
        user_vague_intention = self.user_info["user_vague_intention"]
        if self.mode == 'train':
            user_vague_intention = user_vague_intention["train"]
        elif self.mode == 'test':
            user_vague_intention = user_vague_intention["test"]
        else:
            raise NotImplementedError

        for uid in user_explicit_preference.keys():
            one_data = {}
            user_ep = user_explicit_preference[uid]
            user_vi = user_vague_intention[uid]["querys"]
            one_data["explicit_preferences"] = user_ep
            one_data["user_related_intention"] = user_vi[0]
            one_data["item_related_intention"] = user_vi[1]

            iid = user_vague_intention[uid]["item"]
            inters = user_vague_intention[uid]["inters"]

            # target rendered as its index-token string
            one_data["item"] = "".join(self.indices[str(iid)])

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            inters = ["".join(self.indices[str(i)]) for i in inters]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            one_data["inters"] = self.his_sep.join(inters)

            search_data.append(one_data)

        if self.sample_num > 0:
            all_idx = range(len(search_data))
            sample_idx = np.random.choice(all_idx, self.sample_num, replace=False)
            search_data = np.array(search_data)[sample_idx].tolist()

        return search_data

    def set_prompt(self, prompt_id):
        """Select the fixed prompt used in test mode."""
        self.prompt_id = prompt_id

    def __len__(self):
        if self.mode == 'train':
            return len(self.search_data) * self.prompt_sample_num
        elif self.mode == 'test':
            return len(self.search_data)
        else:
            # only 'train'/'test' instances can exist (_process_data rejects
            # other modes in __init__); fail loudly like the sibling datasets
            # instead of silently returning a length
            raise NotImplementedError

    def _get_text_data(self, data, prompt):
        """Render ``data`` through ``prompt``; at test time the label is the bare response."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        if self.mode == 'test':
            return input, response

        return input, output

    def __getitem__(self, index):
        idx = index // self.prompt_sample_num

        d = self.search_data[idx]
        if self.mode == 'train':
            # random prompt per access acts as augmentation
            prompt_id = random.randint(0, len(self.prompts) - 1)
        elif self.mode == 'test':
            prompt_id = self.prompt_id
        else:
            # guard: avoids an UnboundLocalError on an unexpected mode
            raise NotImplementedError

        prompt = self.prompts[prompt_id]

        # sample one preference (deep-copied so formatting cannot mutate the
        # stored list) and one of the two intention queries per access
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))
        all_querys = [d["user_related_intention"], d["item_related_intention"]]
        d["query"] = random.choice(all_querys)

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
|
|
|
|
|
|
|
|
|
|
|
|
class PreferenceObtainFinetune(BaseDataset):
    """Dataset for eliciting a user's explicit preference from their history."""

    def __init__(self, args, prompt_sample_num=1, sample_num=-1):
        super().__init__(args)

        self.prompt_sample_num = prompt_sample_num  # prompts sampled per user
        self.sample_num = sample_num                # subsample size (-1 = keep all)

        self.prompts = all_prompt["preferenceobtain"]

        self._load_data()
        self._remap_items()

        self.preference_data = self._process_data()

    def _load_data(self):
        """Read user info, interactions, and the item index mapping."""
        with open(os.path.join(self.data_path, self.dataset + ".user.json"), 'r') as f:
            self.user_info = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)

        with open(self.index_file, 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Map each user's raw item ids to index-token strings."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_data(self):
        """One example per user: truncated history plus explicit preferences."""
        preference_data = []
        explicit_prefs = self.user_info["user_explicit_preference"]

        for uid, user_ep in explicit_prefs.items():
            # drop the last three interactions (held out by other tasks)
            inters = self.remapped_inters[uid][:-3]

            if self.max_his_len > 0:
                inters = inters[-self.max_his_len:]
            if self.add_prefix:
                inters = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(inters)]

            preference_data.append({
                "explicit_preferences": user_ep,
                "inters": self.his_sep.join(inters),
            })

        if self.sample_num > 0:
            picked = np.random.choice(range(len(preference_data)), self.sample_num, replace=False)
            preference_data = np.array(preference_data)[picked].tolist()

        return preference_data

    def set_prompt(self, prompt_id):
        self.prompt_id = prompt_id

    def __len__(self):
        return len(self.preference_data) * self.prompt_sample_num

    def _get_text_data(self, data, prompt):
        """Render one (input, label) pair from ``data`` via ``prompt``."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")
        output = sft_prompt.format(instruction=instruction, response=response)

        return input, output

    def __getitem__(self, index):
        d = self.preference_data[index // self.prompt_sample_num]
        # a random prompt per access acts as augmentation
        prompt = self.prompts[random.randint(0, len(self.prompts) - 1)]

        # sample one explicit preference, deep-copied so formatting cannot
        # mutate the stored list
        d["explicit_preference"] = copy.deepcopy(random.choice(d["explicit_preferences"]))

        input, output = self._get_text_data(d, prompt)

        return dict(input_ids=input, labels=output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SeqRecTestDataset(BaseDataset):
    """Test-only sequential recommendation dataset with a single fixed prompt.

    One example per user: predict the final interaction from everything
    before it.  The label is the bare target index string (not the full
    SFT-formatted output).
    """

    def __init__(self, args, prompt_id=0, sample_num=-1):
        super().__init__(args)

        self.prompt_id = prompt_id   # which seqrec prompt to render with
        self.sample_num = sample_num # subsample size (-1 = keep all)

        self.prompt = all_prompt["seqrec"][self.prompt_id]

        self._load_data()
        self._remap_items()

        self.inter_data = self._process_test_data()

    def _load_data(self):
        """Read interactions and the dataset-scoped index file."""
        with open(os.path.join(self.data_path, self.dataset + ".inter.json"), 'r') as f:
            self.inters = json.load(f)
        with open(os.path.join(self.data_path, self.dataset + self.index_file), 'r') as f:
            self.indices = json.load(f)

    def _remap_items(self):
        """Map each user's raw item ids to index-token strings."""
        self.remapped_inters = {
            uid: ["".join(self.indices[str(i)]) for i in items]
            for uid, items in self.inters.items()
        }

    def _process_test_data(self):
        """Build one example per user targeting the last item."""
        inter_data = []
        for uid, items in self.remapped_inters.items():
            history = items[:-1]
            if self.max_his_len > 0:
                history = history[-self.max_his_len:]
            if self.add_prefix:
                history = [str(k + 1) + ". " + item_idx for k, item_idx in enumerate(history)]
            inter_data.append({
                "item": items[-1],
                "inters": self.his_sep.join(history),
            })

        if self.sample_num > 0:
            picked = np.random.choice(range(len(inter_data)), self.sample_num, replace=False)
            inter_data = np.array(inter_data)[picked].tolist()

        return inter_data

    def set_prompt(self, prompt_id):
        """Switch to a different seqrec prompt."""
        self.prompt_id = prompt_id

        self.prompt = all_prompt["seqrec"][self.prompt_id]

    def __len__(self):
        return len(self.inter_data)

    def _get_text_data(self, data, prompt):
        """Render the model input; the target is the bare response text."""
        instruction = prompt["instruction"].format(**data)
        response = prompt["response"].format(**data)

        input = sft_prompt.format(instruction=instruction, response="")

        return input, response

    def __getitem__(self, index):
        d = self.inter_data[index]
        input, target = self._get_text_data(d, self.prompt)

        return dict(input_ids=input, labels=target)