| import os
|
| import collections
|
| import json
|
| import logging
|
| import argparse
|
| import numpy as np
|
| import pandas as pd
|
| import torch
|
| from time import time
|
| from torch import optim
|
| from tqdm import tqdm
|
| from torch.utils.data import DataLoader
|
|
|
| from rq_llama import *
|
| from index.datasets import EmbDataset
|
|
|
def if_collided(all_indices_str):
    """Return True when every index string is unique, i.e. there are NO collisions.

    Note that the name reads inverted: the script's re-indexing loop stops as soon
    as this returns True.

    Args:
        all_indices_str: 1-D numpy array (or any sized sequence) of per-item
            index strings.

    Returns:
        bool: True if no two items share the same index string (also True for
        empty input), False otherwise.
    """
    # numpy arrays are converted to plain Python objects first; other sequences
    # (e.g. lists) are accepted as-is instead of failing on a missing ``.tolist``.
    if hasattr(all_indices_str, "tolist"):
        codes = all_indices_str.tolist()
    else:
        codes = list(all_indices_str)
    return len(codes) == len(set(codes))
|
|
|
def get_indices_count(all_indices_str):
    """Count how many items were assigned each index string.

    Args:
        all_indices_str: Iterable of per-item index strings.

    Returns:
        collections.Counter mapping each distinct index string to its number of
        occurrences (a dict subclass; missing keys read as 0).
    """
    # Counter is the standard-library replacement for the manual
    # defaultdict(int) increment loop.
    return collections.Counter(all_indices_str)
|
|
|
def get_collision_item(all_indices_str):
    """Group the positions of items that share an identical index string.

    Args:
        all_indices_str: Iterable of per-item index strings.

    Returns:
        list[list[int]]: One list of item positions for every index string that
        occurs more than once. Groups are ordered by the first occurrence of
        their string, and positions within a group are ascending.
    """
    positions_by_code = collections.defaultdict(list)
    for position, code_str in enumerate(all_indices_str):
        positions_by_code[code_str].append(position)
    # Only strings shared by at least two items are collisions.
    return [group for group in positions_by_code.values() if len(group) > 1]
|
|
|
def parse_args(argv=None):
    """Parse the command-line options for the index-generation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, matching the original behavior.

    Returns:
        argparse.Namespace with string attributes ``ckpt_path``, ``data_path``,
        ``save_path`` and ``device_map``.
    """
    parser = argparse.ArgumentParser(description="Index")
    parser.add_argument(
        "--ckpt_path", type=str, default="",
        help="checkpoint directory passed to LlamaWithRQ.from_pretrained",
    )
    parser.add_argument(
        "--data_path", type=str, default="",
        help="embedding data path passed to EmbDataset",
    )
    parser.add_argument(
        "--save_path", type=str, default="",
        help="output JSON file for the generated item indices",
    )
    parser.add_argument(
        "--device_map", type=str, default="1",
        help="gpu index (e.g. 0) or cpu",
    )
    return parser.parse_args(argv)
|
|
|
# --- Setup: parse CLI options, build the data loader, and load the model. ---
args = parse_args()
print(args)

data = EmbDataset(args.data_path)
# shuffle=False keeps loader order aligned with dataset positions, which the
# item ids used later rely on.
data_loader = DataLoader(data, num_workers=4, batch_size=64, shuffle=False, pin_memory=True)

# --device_map is documented as "gpu or cpu": accept the literal "cpu" as well as
# a GPU index. Previously int("cpu") raised ValueError.
if args.device_map.strip().lower() == "cpu":
    device_map = {'': 'cpu'}
else:
    device_map = {'': int(args.device_map)}

# NOTE(review): LlamaWithRQ comes from `rq_llama import *`; it is assumed to be a
# Hugging Face-style model exposing .rqvae and .prefix — confirm in rq_llama.
MODEL = LlamaWithRQ.from_pretrained(args.ckpt_path, torch_dtype=torch.float16, low_cpu_mem_usage=True, device_map=device_map)
MODEL.eval()
device = MODEL.device
rqvae = MODEL.rqvae
# prefix[i] is a format string applied to the level-i code index (see its uses below).
prefix = MODEL.prefix
|
|
|
# --- Initial pass: encode every item and render its per-level code tokens. ---
all_indices = []
all_indices_str = []
with torch.no_grad():
    for x in tqdm(data_loader):
        indices = rqvae.get_indices(x.to(device), False)
        # Flatten to (num_items, num_levels) as the view() below dictates.
        indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        for index in indices:
            # One token per quantizer level, formatted with that level's prefix.
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
            all_indices.append(code)
            all_indices_str.append(str(code))

# dtype=object: a default numpy string array has a fixed width equal to its
# longest element at creation time. The collision-resolution loop later assigns
# new codes into these arrays, and numpy would silently truncate any string
# longer than that width.
all_indices = np.array(all_indices, dtype=object)
all_indices_str = np.array(all_indices_str, dtype=object)
|
|
|
# Zero sk_epsilon on every quantizer except the last one. If the last one's
# sk_epsilon is 0.0, give it 0.003.
# NOTE(review): sk_epsilon presumably controls Sinkhorn-Knopp balancing in the
# quantizers, used during re-indexing — confirm in rq_llama.
vq_layers = rqvae.rq.vq_layers
for layer in vq_layers[:-1]:
    layer.sk_epsilon = 0.0
final_layer = vq_layers[-1]
if final_layer.sk_epsilon == 0.0:
    final_layer.sk_epsilon = 0.003
|
|
|
# Re-encode colliding groups (get_indices called with True) until every index
# string is unique or 20 rounds have run.
tt = 0
while tt < 20 and not if_collided(all_indices_str):
    collision_item_groups = get_collision_item(all_indices_str)
    print(len(collision_item_groups))
    with torch.no_grad():
        for collision_items in collision_item_groups:
            # NOTE(review): assumes EmbDataset supports list indexing and returns a tensor.
            batch = data[collision_items].to(device)
            indices = rqvae.get_indices(batch, True)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for item, index in zip(collision_items, indices):
                code = [prefix[level].format(int(ind)) for level, ind in enumerate(index)]
                all_indices[item] = code
                all_indices_str[item] = str(code)
    tt += 1
|
|
|
# --- Report how many items remain collided after re-indexing. ---
print("All indices number: ", len(all_indices))
# default=0 keeps an empty dataset from raising ValueError in max().
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default=0))

tot_item = len(all_indices_str)
tot_indice = len(set(all_indices_str.tolist()))
# Fraction of items whose index string is shared with an earlier item;
# guarded so an empty dataset reports 0.0 instead of dividing by zero.
print("Collision Rate", (tot_item - tot_indice) / tot_item if tot_item else 0.0)
|
|
|
# Map each item's position to its list of code tokens. json.dump will write the
# integer keys as strings.
all_indices_dict = {
    position: list(tokens) for position, tokens in enumerate(all_indices.tolist())
}
with open(args.save_path, 'w', encoding='utf-8') as fp:
    json.dump(all_indices_dict, fp)