# Benchmark-Single/rq_llama.py
import os
import json
import copy
import wandb
import torch
import torch.nn as nn
import transformers
from transformers import LlamaPreTrainedModel, LlamaForCausalLM, LlamaTokenizer, LlamaConfig
from peft import (
TaskType,
LoraConfig,
get_peft_model,
get_peft_model_state_dict,
set_peft_model_state_dict,
)
from index.models import *
from index.models.rqvae import RQVAE
from torch.nn.init import xavier_normal_
from sklearn.cluster import KMeans
class LlamaWithRQ(LlamaPreTrainedModel):
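    """LlamaForCausalLM wrapped with LoRA adapters and an RQ-VAE.

    The RQ-VAE quantizes item text embeddings into discrete codes, which are exposed
    as extra tokens (<a-*>, <b-*>, ...) in the tokenizer so prompts and labels can
    refer to items by their semantic IDs instead of raw text.
    """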
def __init__(self, config):
super().__init__(config)
args = config.args
tokenizer = LlamaTokenizer.from_pretrained(
args['base_model'],
model_max_length = args['model_max_length'],
padding_side ="right",
)
tokenizer.pad_token_id = 0
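        # Register one set of index tokens per RQ-VAE level (e.g. <a-0> ... <a-{n-1}>) in the vocabulary.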
new_tokens = []
prefix = ['<a-{}>','<b-{}>','<c-{}>','<d-{}>','<e-{}>']
for i in range(len(args['num_emb_list'])):
new_tokens.extend([prefix[i].format(int(x)) for x in range(args['num_emb_list'][i])])
self.prefix = prefix
tokenizer.add_tokens(new_tokens)
config.vocab_size = len(tokenizer)
llama_model = LlamaForCausalLM.from_pretrained(args['base_model'])
llama_model.resize_token_embeddings(len(tokenizer))
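        # Wrap the base model with LoRA adapters; layers listed in lora_modules_to_save
        # (typically the embeddings / LM head) are kept fully trainable so the newly
        # added index tokens can be learned.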
lora_config = LoraConfig(
r = args['lora_r'],
lora_alpha = args['lora_alpha'],
target_modules = args['lora_target_modules'].split(","),
modules_to_save = args['lora_modules_to_save'].split(","),
lora_dropout = args['lora_dropout'],
bias = "none",
inference_mode = False,
task_type = TaskType.CAUSAL_LM
)
llama_model = get_peft_model(llama_model, lora_config)
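        # PEFT keeps a frozen copy of every modules_to_save layer under "original_module";
        # make sure those copies stay frozen.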
for n, p in llama_model.named_parameters():
if "original_module" in n and any(module_name in n for module_name in lora_config.modules_to_save):
p.requires_grad = False
self.tokenizer = tokenizer
self.model = llama_model
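        # Item-id -> {"title", "description"} metadata used to build the text that is embedded and quantized.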
item_json = os.path.join(args['data_path'], args['dataset'], args['dataset'] + ".item.json")
with open(item_json, 'r') as f:
self.item_texts = json.load(f)
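        # RQ-VAE that quantizes mean-pooled Llama hidden states (in_dim = hidden_size)
        # into one codebook index per level of num_emb_list.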
self.rqvae = RQVAE(in_dim = config.hidden_size,
num_emb_list = args['num_emb_list'],
e_dim = args['e_dim'],
layers = args['layers'],
dropout_prob = args['dropout_prob'],
bn = args['bn'],
loss_type = args['loss_type'],
quant_loss_weight = args['quant_loss_weight'],
kmeans_init = args['kmeans_init'],
kmeans_iters = args['kmeans_iters'],
sk_epsilons = args['sk_epsilons'],
sk_iters = args['sk_iters'])
# self.projector = nn.Linear(args['e_dim'], config.hidden_size)
self.args = args
def rqvae_forward(self, inputs, targets, inters, item, task):
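        """Embed item texts with the decoder, quantize them with the RQ-VAE, and splice
        the resulting index tokens into the prompt (`inputs`) and label (`targets`)
        templates of the given task. Returns (inputs, targets, rqvae_loss, num_embeddings).
        """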
llama_model = self.model.get_decoder()
if task.lower() in ['seqrec', 'itemsearch']:
# inputs, targets, inters, item
# item-id to text
inter_emb_list = []
inter_item_list = inters.split(',')
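            # Mean-pool the decoder's last hidden states over non-pad positions to get one
            # embedding per history item; detach so the RQ-VAE objective does not update Llama.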
for j in range(len(inter_item_list)):
inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description']
inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device)
inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask)
inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True)
inter_emb_list.append(inter_emb.detach())
inter_embs = torch.cat(inter_emb_list, dim = 0)
item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description']
item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device)
item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask)
item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True)
item_emb = item_emb.detach()
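            # Quantize the history and target embeddings in one batch; rqids holds one
            # codebook index per level for each embedding.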
rec_embs, rq_loss, rqids = self.rqvae(torch.cat([inter_embs, item_emb], dim = 0))
rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, torch.cat([inter_embs, item_emb], dim = 0))
inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[:-1]
item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[-1]
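            # Render the integer codes as the newly added <a-*>/<b-*>/... tokens and fill
            # the {inters}/{item} placeholders in the templates.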
text_rqids = {}
code = ''
for rqid in inters_rqids:
for k, idx in enumerate(rqid):
code = code + self.prefix[k].format(idx)
code = code + ', '
text_rqids['inters'] = code[:-2]
code = ''
for k, idx in enumerate(item_rqid):
code = code + self.prefix[k].format(idx)
text_rqids['item'] = code
inputs = inputs.format(inters = text_rqids['inters'])
targets = targets.format(inters = text_rqids['inters'], item = text_rqids['item'])
elif task.lower() in ['inters2title','inters2description','preferenceobtain']:
# inputs, targets, inters
inter_emb_list = []
inter_item_list = inters.split(',')
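            # Same mean-pooled history embeddings as above; these tasks have no target item to quantize.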
for j in range(len(inter_item_list)):
inter_feature = self.item_texts[inter_item_list[j]]['title'] + ' ' + self.item_texts[inter_item_list[j]]['description']
inter_id = self.tokenizer(inter_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device)
inter_emb = llama_model(input_ids = inter_id.input_ids, attention_mask = inter_id.attention_mask)
inter_emb = inter_emb.last_hidden_state * inter_id.attention_mask.unsqueeze(-1)
inter_emb = inter_emb.sum(dim=1) / inter_id.attention_mask.sum(dim = -1, keepdim = True)
inter_emb_list.append(inter_emb.detach())
inter_embs = torch.cat(inter_emb_list, dim = 0)
rec_embs, rq_loss, rqids = self.rqvae(inter_embs)
rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, inter_embs)
inters_rqids = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()
code = ''
for rqid in inters_rqids:
for k, idx in enumerate(rqid):
code = code + self.prefix[k].format(idx)
code = code + ', '
inputs = inputs.format(inters = code[:-2])
targets = targets.format(inters = code[:-2])
elif task.lower() in ['item2index','index2item','intertitles2item','query2item']:
# inputs, targets, item
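            # Embed and quantize the single target item; only the label template is filled with the resulting code.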
item_feature = self.item_texts[item]['title'] + ' ' + self.item_texts[item]['description']
item_ids = self.tokenizer(item_feature, return_tensors = 'pt', padding=True, truncation=True).to(self.model.device)
item_emb = llama_model(input_ids = item_ids.input_ids, attention_mask = item_ids.attention_mask)
item_emb = item_emb.last_hidden_state * item_ids.attention_mask.unsqueeze(-1)
item_emb = item_emb.sum(dim=1) / item_ids.attention_mask.sum(dim = -1, keepdim = True)
item_emb = item_emb.detach()
rec_embs, rq_loss, rqids = self.rqvae(item_emb)
rqvae_loss, rec_loss = self.rqvae.compute_loss(rec_embs, rq_loss, item_emb)
item_rqid = rqids.view(-1, rqids.shape[-1]).cpu().numpy().tolist()[0]
code = ''
for k, idx in enumerate(item_rqid):
code = code + self.prefix[k].format(idx)
targets = targets.format(item = code)
else:
raise NotImplementedError
return inputs, targets, rqvae_loss, rec_embs.shape[0]
def forward(self, input_ids, labels, inters, item, task):
'''
'input_ids':
[
"Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy.
The historical interactions are provided as follows: {inters}.
### Response:",
'Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
You have obtained the ordered list of user historical interaction items, which is as follows: {inters}.
Using this history as a reference, please select the next item to recommend to the user.
### Response:'
],
'labels':
[
"Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
Using the user's historical interactions as input data, suggest the next item that the user is highly likely to enjoy.
The historical interactions are provided as follows: {inters}.
### Response:{item}",
'Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
You have obtained the ordered list of user historical interaction items, which is as follows: {inters}.
Using this history as a reference, please select the next item to recommend to the user.
### Response:{item}'
],
'inters': ['0', '0,1'],
'item': ['1', '2'],
'task': ['seqrec', 'seqrec']
'''
assert len(set([len(input_ids), len(labels), len(inters), len(item), len(task)])) == 1
num_data = len(task)
total_rqvae_loss = 0
total_num_sample = 0
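        # Replace the placeholders in every sample with RQ-VAE codes and accumulate the quantization loss.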
for i in range(num_data):
input_ids[i], labels[i], rqvae_loss, num_sample = self.rqvae_forward(input_ids[i], labels[i], inters[i], item[i], task[i])
total_rqvae_loss += rqvae_loss
total_num_sample += num_sample
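        # Tokenize the full prompt+response texts (labels) as the model inputs; the prompt-only
        # texts go through text_target and are used below solely to locate prompt positions.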
input_data = self.tokenizer(
text = labels,
text_target = input_ids,
return_tensors = 'pt',
padding = 'longest',
truncation = True,
max_length = self.tokenizer.model_max_length,
return_attention_mask = True
).to(self.model.device)
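        # LM labels start as a copy of the full sequence; when only_train_response is set,
        # mask padding and every prompt position so the loss is computed on response tokens
        # only (right padding keeps prompt indices aligned between the two tokenizations).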
labels = copy.deepcopy(input_data["input_ids"])
if self.args['only_train_response']:
labels[labels == self.tokenizer.pad_token_id] = -100
labels[torch.where(input_data["labels"] != self.tokenizer.pad_token_id)] = -100
input_data["labels"] = labels
# codebook_embedding = []
# for i in range(len(self.rqvae.num_emb_list)):
# codebook_embedding.append(self.rqvae.rq.vq_layers[i].embedding.weight.data)
# codebook_embedding = torch.cat(codebook_embedding, dim = 0)
# codebook_embedding = self.projector(codebook_embedding)
# self.model.model.model.embed_tokens.weight.data[-codebook_embedding.shape[0]:] = codebook_embedding
result = self.model(**input_data)
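        # Log both loss terms and optimize their sum (RQ-VAE loss averaged over the quantized embeddings).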
wandb.log({'Llama_Loss': result.loss, 'RQVAE_Loss': total_rqvae_loss / total_num_sample})
result.loss += total_rqvae_loss / total_num_sample
wandb.log({'Total_Loss': result.loss})
return result
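    # Inputs here are raw strings, so return 0 and skip the HF Trainer's default FLOPs estimation.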
def floating_point_ops(self, inputs):
return 0