| | import os
|
| | import json
|
| | import copy
|
| | import wandb
|
| | import torch
|
| | import torch.nn as nn
|
| | import transformers
|
| | from transformers import LlamaPreTrainedModel, LlamaForCausalLM, LlamaTokenizer, LlamaConfig
|
| |
|
| | from peft import (
|
| | TaskType,
|
| | LoraConfig,
|
| | get_peft_model,
|
| | get_peft_model_state_dict,
|
| | set_peft_model_state_dict,
|
| | )
|
| |
|
| | from index.models import *
|
| | from index.models.rqvae import RQVAE
|
| |
|
| | from torch.nn.init import xavier_normal_
|
| | from sklearn.cluster import KMeans
|
| |
|
class LlamaWithRQ(LlamaPreTrainedModel):
    """LLaMA causal LM jointly trained with an RQ-VAE item tokenizer.

    Item texts (title + description) are embedded with the LLaMA decoder
    (mean-pooled, detached), quantized by the RQ-VAE into discrete residual
    codes, and the resulting code tokens (e.g. ``<a-3><b-7>...``) are
    substituted into the ``{inters}`` / ``{item}`` slots of the prompt and
    label templates before the language-model forward pass.  The returned
    loss is the LM loss plus the mean RQ-VAE loss over the batch.
    """

    def __init__(self, config):
        super().__init__(config)

        args = config.args

        # Right-padded tokenizer; pad id 0 matches LLaMA's <unk>.
        tokenizer = LlamaTokenizer.from_pretrained(
            args['base_model'],
            model_max_length=args['model_max_length'],
            padding_side="right",
        )
        tokenizer.pad_token_id = 0

        # One new special token per codebook entry per quantization level,
        # e.g. <a-0>..<a-255>, <b-0>..  NOTE(review): at most len(prefix) == 5
        # levels are supported; more entries in num_emb_list would IndexError.
        new_tokens = []
        prefix = ['<a-{}>', '<b-{}>', '<c-{}>', '<d-{}>', '<e-{}>']
        for level, num_emb in enumerate(args['num_emb_list']):
            new_tokens.extend(prefix[level].format(int(x)) for x in range(num_emb))
        self.prefix = prefix

        tokenizer.add_tokens(new_tokens)
        config.vocab_size = len(tokenizer)

        llama_model = LlamaForCausalLM.from_pretrained(args['base_model'])
        llama_model.resize_token_embeddings(len(tokenizer))

        lora_config = LoraConfig(
            r=args['lora_r'],
            lora_alpha=args['lora_alpha'],
            target_modules=args['lora_target_modules'].split(","),
            modules_to_save=args['lora_modules_to_save'].split(","),
            lora_dropout=args['lora_dropout'],
            bias="none",
            inference_mode=False,
            task_type=TaskType.CAUSAL_LM,
        )
        llama_model = get_peft_model(llama_model, lora_config)

        # PEFT keeps a frozen copy ("original_module") of each module listed
        # in modules_to_save; make sure those copies never receive gradients.
        for n, p in llama_model.named_parameters():
            if "original_module" in n and any(
                module_name in n for module_name in lora_config.modules_to_save
            ):
                p.requires_grad = False

        self.tokenizer = tokenizer
        self.model = llama_model

        # Raw item metadata keyed by item-id string; each entry is expected
        # to carry 'title' and 'description' fields (see _embed_item).
        item_json = os.path.join(
            args['data_path'], args['dataset'], args['dataset'] + ".item.json"
        )
        with open(item_json, 'r') as f:
            self.item_texts = json.load(f)

        self.rqvae = RQVAE(
            in_dim=config.hidden_size,
            num_emb_list=args['num_emb_list'],
            e_dim=args['e_dim'],
            layers=args['layers'],
            dropout_prob=args['dropout_prob'],
            bn=args['bn'],
            loss_type=args['loss_type'],
            quant_loss_weight=args['quant_loss_weight'],
            kmeans_init=args['kmeans_init'],
            kmeans_iters=args['kmeans_iters'],
            sk_epsilons=args['sk_epsilons'],
            sk_iters=args['sk_iters'],
        )

        self.args = args

    def _embed_item(self, decoder, item_key):
        """Return the mean-pooled decoder embedding of one item's text.

        The title and description are concatenated, tokenized, run through
        the LLaMA decoder, mean-pooled over non-padding positions, and
        detached so the RQ-VAE branch does not backpropagate into LLaMA.
        """
        feature = (
            self.item_texts[item_key]['title']
            + ' '
            + self.item_texts[item_key]['description']
        )
        ids = self.tokenizer(
            feature, return_tensors='pt', padding=True, truncation=True
        ).to(self.model.device)
        out = decoder(input_ids=ids.input_ids, attention_mask=ids.attention_mask)
        # Zero out padding positions, then average over valid tokens.
        emb = out.last_hidden_state * ids.attention_mask.unsqueeze(-1)
        emb = emb.sum(dim=1) / ids.attention_mask.sum(dim=-1, keepdim=True)
        return emb.detach()

    def _rqid_to_code(self, rqid):
        """Render one residual-quantization id list as its token string,
        e.g. ``[3, 7] -> '<a-3><b-7>'``."""
        return ''.join(self.prefix[k].format(idx) for k, idx in enumerate(rqid))

    def rqvae_forward(self, inputs, targets, inters, item, task):
        """Fill one sample's ``{inters}``/``{item}`` template slots with
        RQ-VAE code strings.

        Args:
            inputs:  prompt template (response not filled in).
            targets: same template with the response appended.
            inters:  comma-separated interacted item-id string.
            item:    target item-id string.
            task:    task name selecting which slots need quantized codes.

        Returns:
            Tuple of (formatted inputs, formatted targets, RQ-VAE loss,
            number of embeddings quantized for this sample).
        """
        decoder = self.model.get_decoder()
        task = task.lower()

        if task in ('seqrec', 'itemsearch'):
            # Quantize every history item plus the target item in one batch.
            inter_embs = torch.cat(
                [self._embed_item(decoder, key) for key in inters.split(',')],
                dim=0,
            )
            item_emb = self._embed_item(decoder, item)

            all_embs = torch.cat([inter_embs, item_emb], dim=0)
            rec_embs, rq_loss, rqids = self.rqvae(all_embs)
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, all_embs)

            flat_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()
            # Last row belongs to the target item; the rest are the history.
            inters_code = ', '.join(self._rqid_to_code(r) for r in flat_rqids[:-1])
            item_code = self._rqid_to_code(flat_rqids[-1])

            inputs = inputs.format(inters=inters_code)
            targets = targets.format(inters=inters_code, item=item_code)

        elif task in ('inters2title', 'inters2description', 'preferenceobtain'):
            inter_embs = torch.cat(
                [self._embed_item(decoder, key) for key in inters.split(',')],
                dim=0,
            )

            rec_embs, rq_loss, rqids = self.rqvae(inter_embs)
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, inter_embs)

            flat_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()
            inters_code = ', '.join(self._rqid_to_code(r) for r in flat_rqids)

            inputs = inputs.format(inters=inters_code)
            targets = targets.format(inters=inters_code)

        elif task in ('item2index', 'index2item', 'intertitles2item', 'query2item'):
            item_emb = self._embed_item(decoder, item)

            rec_embs, rq_loss, rqids = self.rqvae(item_emb)
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, item_emb)

            item_code = self._rqid_to_code(
                rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[0]
            )

            # NOTE(review): only ``targets`` is formatted here, matching the
            # original behavior — if these task templates also put ``{item}``
            # in the prompt, ``inputs`` would need formatting too; confirm
            # against the prompt definitions.
            targets = targets.format(item=item_code)
        else:
            raise NotImplementedError(f"rqvae_forward: unknown task '{task}'")

        return inputs, targets, rqvae_loss, rec_embs.shape[0]

    def forward(self, input_ids, labels, inters, item, task):
        """Run one training step over a batch of template strings.

        Args:
            input_ids: list of prompt templates with ``{inters}`` slots,
                ending in "### Response:" (response not filled in).
            labels: list of the same templates with ``{item}`` appended
                after "### Response:".
            inters: list of comma-separated interacted item-id strings,
                e.g. ``['0', '0,1']``.
            item: list of target item-id strings, e.g. ``['1', '2']``.
            task: list of task names, e.g. ``['seqrec', 'seqrec']``.

        Returns:
            The LM output object with ``loss`` augmented by the mean
            RQ-VAE loss; also logs both losses to wandb.
        """
        assert len(set([len(input_ids), len(labels), len(inters), len(item), len(task)])) == 1
        num_data = len(task)

        total_rqvae_loss = 0
        total_num_sample = 0
        # Fill templates in place so the lists below hold formatted text.
        for i in range(num_data):
            input_ids[i], labels[i], rqvae_loss, num_sample = self.rqvae_forward(
                input_ids[i], labels[i], inters[i], item[i], task[i]
            )
            total_rqvae_loss += rqvae_loss
            total_num_sample += num_sample

        # Tokenize the full sequences (prompt + response) as input_ids and
        # the prompt-only text as ``labels`` (via text_target) so the prompt
        # span can be masked out below.
        input_data = self.tokenizer(
            text=labels,
            text_target=input_ids,
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_attention_mask=True,
        ).to(self.model.device)

        lm_labels = copy.deepcopy(input_data["input_ids"])
        if self.args['only_train_response']:
            # Ignore padding and every position covered by the prompt so the
            # LM loss is computed over response tokens only (-100 is the
            # ignore index of the cross-entropy loss).
            lm_labels[lm_labels == self.tokenizer.pad_token_id] = -100
            lm_labels[torch.where(input_data["labels"] != self.tokenizer.pad_token_id)] = -100

        input_data["labels"] = lm_labels

        result = self.model(**input_data)
        mean_rqvae_loss = total_rqvae_loss / total_num_sample
        wandb.log({'Llama_Loss': result.loss, 'RQVAE_Loss': mean_rqvae_loss})
        result.loss += mean_rqvae_loss
        wandb.log({'Total_Loss': result.loss})
        return result

    def floating_point_ops(self, inputs):
        """Report zero FLOPs so the HF Trainer's speed metrics skip this model."""
        return 0