import os
import json
import copy

import wandb
import torch
import torch.nn as nn
import transformers
from transformers import LlamaPreTrainedModel, LlamaForCausalLM, LlamaTokenizer, LlamaConfig
from peft import (
    TaskType,
    LoraConfig,
    get_peft_model,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
)
from index.models import *
from index.models.rqvae import RQVAE
from torch.nn.init import xavier_normal_
from sklearn.cluster import KMeans


class LlamaWithRQ(LlamaPreTrainedModel):
    """LLaMA causal LM trained jointly with an RQ-VAE.

    The RQ-VAE quantizes mean-pooled item text embeddings into a short sequence
    of codebook indices; each index is rendered as a special token (e.g. "<a_3>")
    that is added to the tokenizer vocabulary, so items can be referenced by
    their semantic IDs inside the instruction prompts.
    """

    def __init__(self, config):
        super().__init__(config)
        args = config.args

        tokenizer = LlamaTokenizer.from_pretrained(
            args['base_model'],
            model_max_length=args['model_max_length'],
            padding_side="right",
        )
        tokenizer.pad_token_id = 0

        # One token template per RQ-VAE codebook level; formatting a template with
        # a code index yields the new vocabulary token, e.g. "<a_{}>".format(3) -> "<a_3>".
        new_tokens = []
        prefix = ['<a_{}>', '<b_{}>', '<c_{}>', '<d_{}>', '<e_{}>']
        for i in range(len(args['num_emb_list'])):
            new_tokens.extend([prefix[i].format(int(x)) for x in range(args['num_emb_list'][i])])
        self.prefix = prefix
        tokenizer.add_tokens(new_tokens)
        config.vocab_size = len(tokenizer)

        llama_model = LlamaForCausalLM.from_pretrained(args['base_model'])
        llama_model.resize_token_embeddings(len(tokenizer))

        lora_config = LoraConfig(
            r=args['lora_r'],
            lora_alpha=args['lora_alpha'],
            target_modules=args['lora_target_modules'].split(","),
            modules_to_save=args['lora_modules_to_save'].split(","),
            lora_dropout=args['lora_dropout'],
            bias="none",
            inference_mode=False,
            task_type=TaskType.CAUSAL_LM,
        )
        llama_model = get_peft_model(llama_model, lora_config)
        # PEFT keeps a frozen copy of each module listed in `modules_to_save`;
        # make sure those "original_module" copies never receive gradients.
        for n, p in llama_model.named_parameters():
            if "original_module" in n and any(module_name in n for module_name in lora_config.modules_to_save):
                p.requires_grad = False

        self.tokenizer = tokenizer
        self.model = llama_model

        item_json = os.path.join(args['data_path'], args['dataset'], args['dataset'] + ".item.json")
        with open(item_json, 'r') as f:
            self.item_texts = json.load(f)

        self.rqvae = RQVAE(
            in_dim=config.hidden_size,
            num_emb_list=args['num_emb_list'],
            e_dim=args['e_dim'],
            layers=args['layers'],
            dropout_prob=args['dropout_prob'],
            bn=args['bn'],
            loss_type=args['loss_type'],
            quant_loss_weight=args['quant_loss_weight'],
            kmeans_init=args['kmeans_init'],
            kmeans_iters=args['kmeans_iters'],
            sk_epsilons=args['sk_epsilons'],
            sk_iters=args['sk_iters'],
        )
        # self.projector = nn.Linear(args['e_dim'], config.hidden_size)
        self.args = args

    def rqvae_forward(self, inputs, targets, inters, item, task):
        """Quantize the items referenced by one sample and fill its prompt templates.

        Returns the formatted (inputs, targets) strings, the RQ-VAE loss for this
        sample, and the number of quantized embeddings it covers.
        """
        llama_model = self.model.get_decoder()

        if task.lower() in ['seqrec', 'itemsearch']:
            # Templates use {inters} and {item}: encode the interaction history and the target item.
            inter_emb_list = []
            inter_item_list = inters.split(',')
            for j in range(len(inter_item_list)):
                # Item id -> item text -> mean-pooled decoder embedding.
                inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description']
                inter_id = self.tokenizer(inter_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
                inter_emb = llama_model(input_ids=inter_id.input_ids, attention_mask=inter_id.attention_mask)
                inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
                inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim=-1, keepdim=True)
                inter_emb_list.append(inter_emb.detach())
            inter_embs = torch.cat(inter_emb_list, dim=0)

            item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description']
            item_ids = self.tokenizer(item_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
            item_emb = llama_model(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
            item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
            item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim=-1, keepdim=True)
            item_emb = item_emb.detach()

            rec_embs, rq_loss, rqids = self.rqvae(torch.cat([inter_embs, item_emb], dim=0))
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, torch.cat([inter_embs, item_emb], dim=0))

            # The last row of codes belongs to the target item, the rest to the history.
            inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1]
            item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1]

            text_rqids = {}
            code = ''
            for rqid in inters_rqids:
                for k, idx in enumerate(rqid):
                    code = code + self.prefix[k].format(idx)
                code = code + ', '
            text_rqids['inters'] = code[:-2]
            code = ''
            for k, idx in enumerate(item_rqid):
                code = code + self.prefix[k].format(idx)
            text_rqids['item'] = code

            inputs = inputs.format(inters=text_rqids['inters'])
            targets = targets.format(inters=text_rqids['inters'], item=text_rqids['item'])

        elif task.lower() in ['inters2title', 'inters2description', 'preferenceobtain']:
            # Templates only use {inters}: encode the interaction history.
            inter_emb_list = []
            inter_item_list = inters.split(',')
            for j in range(len(inter_item_list)):
                inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description']
                inter_id = self.tokenizer(inter_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
                inter_emb = llama_model(input_ids=inter_id.input_ids, attention_mask=inter_id.attention_mask)
                inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
                inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim=-1, keepdim=True)
                inter_emb_list.append(inter_emb.detach())
            inter_embs = torch.cat(inter_emb_list, dim=0)

            rec_embs, rq_loss, rqids = self.rqvae(inter_embs)
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, inter_embs)

            inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()
            code = ''
            for rqid in inters_rqids:
                for k, idx in enumerate(rqid):
                    code = code + self.prefix[k].format(idx)
                code = code + ', '
            inputs = inputs.format(inters=code[:-2])
            targets = targets.format(inters=code[:-2])

        elif task.lower() in ['item2index', 'index2item', 'intertitles2item', 'query2item']:
            # Templates only use {item}: encode the single target item.
            item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description']
            item_ids = self.tokenizer(item_feature, return_tensors='pt', padding=True, truncation=True).to(self.model.device)
            item_emb = llama_model(input_ids=item_ids.input_ids, attention_mask=item_ids.attention_mask)
            item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
            item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim=-1, keepdim=True)
            item_emb = item_emb.detach()

            rec_embs, rq_loss, rqids = self.rqvae(item_emb)
            rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, item_emb)

            item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[0]
            code = ''
            for k, idx in enumerate(item_rqid):
                code = code + self.prefix[k].format(idx)
            targets = targets.format(item=code)

        else:
            raise NotImplementedError

        return inputs, targets, rqvae_loss, rec_embs.shape[0]

    def forward(self, input_ids, labels, inters, item, task):
        '''
        Example batch (the fields are parallel lists, one entry per sample):

        'input_ids': [
            "Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}. ### Response:",
            'Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user. ### Response:'
        ],
        'labels': [
            "Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy. The historical interactions are provided as follows: {inters}. ### Response:{item}",
            'Below is an instruction that describes a task. Write a response that appropriately completes the request. ### Instruction: You have obtained the ordered list of user historical interaction items, which is as follows: {inters}. Using this history as a reference, please select the next item to recommend to the user. ### Response:{item}'
        ],
        'inters': ['0', '0,1'],
        'item': ['1', '2'],
        'task': ['seqrec', 'seqrec']
        '''
        assert len(set([len(input_ids), len(labels), len(inters), len(item), len(task)])) == 1
        num_data = len(task)

        # Fill every template with its RQ-VAE codes and accumulate the quantization loss.
        total_rqvae_loss = 0
        total_num_sample = 0
        for i in range(num_data):
            input_ids[i], labels[i], rqvae_loss, num_sample = self.rqvae_forward(
                input_ids[i], labels[i], inters[i], item[i], task[i]
            )
            total_rqvae_loss += rqvae_loss
            total_num_sample += num_sample

        # `labels` holds prompt + response, `input_ids` holds the prompt only; the
        # prompt-only tokenization lands in input_data["labels"] and is used below
        # to mask out prompt positions when only the response is trained.
        input_data = self.tokenizer(
            text=labels,
            text_target=input_ids,
            return_tensors='pt',
            padding='longest',
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_attention_mask=True,
        ).to(self.model.device)

        labels = copy.deepcopy(input_data["input_ids"])
        if self.args['only_train_response']:
            labels[labels == self.tokenizer.pad_token_id] = -100
            labels[torch.where(input_data["labels"] != self.tokenizer.pad_token_id)] = -100
        input_data["labels"] = labels

        # codebook_embedding = []
        # for i in range(len(self.rqvae.num_emb_list)):
        #     codebook_embedding.append(self.rqvae.rq.vq_layers[i].embedding.weight.data)
        # codebook_embedding = torch.cat(codebook_embedding, dim = 0)
        # codebook_embedding = self.projector(codebook_embedding)
        # self.model.model.model.embed_tokens.weight.data[-codebook_embedding.shape[0]:] = codebook_embedding

        result = self.model(**input_data)
        wandb.log({'Llama_Loss': result.loss, 'RQVAE_Loss': total_rqvae_loss / total_num_sample})
        result.loss += total_rqvae_loss / total_num_sample
        wandb.log({'Total_Loss': result.loss})
        return result

    def floating_point_ops(self, inputs):
        return 0
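

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the training pipeline above). Every value
# below is an illustrative assumption: the checkpoint name, hyperparameters,
# and dataset layout must match whatever the surrounding project actually
# uses; the real entry point is expected to build `config.args` itself.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    base_model = "meta-llama/Llama-2-7b-hf"  # assumed base checkpoint
    config = LlamaConfig.from_pretrained(base_model)
    config.args = {
        'base_model': base_model,
        'model_max_length': 512,
        'num_emb_list': [256, 256, 256, 256],    # assumed: 4 codebook levels, 256 codes each
        'e_dim': 32,
        'layers': [2048, 1024, 512, 256, 128, 64],
        'dropout_prob': 0.0,
        'bn': False,
        'loss_type': 'mse',
        'quant_loss_weight': 1.0,
        'kmeans_init': True,
        'kmeans_iters': 100,
        'sk_epsilons': [0.0, 0.0, 0.0, 0.003],
        'sk_iters': 50,
        'lora_r': 8,
        'lora_alpha': 16,
        'lora_dropout': 0.05,
        'lora_target_modules': "q_proj,v_proj",
        'lora_modules_to_save': "embed_tokens,lm_head",
        'data_path': "./data",                   # assumed: expects <data_path>/<dataset>/<dataset>.item.json
        'dataset': "Instruments",
        'only_train_response': True,
    }
    model = LlamaWithRQ(config)
    # `model.forward` expects parallel lists of prompt templates, label templates,
    # comma-separated interaction item ids, target item ids, and task names, as
    # illustrated in the docstring of `forward`.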