|
|
import os
|
|
|
import collections
|
|
|
import json
|
|
|
import logging
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import torch
|
|
|
from time import time
|
|
|
from torch import optim
|
|
|
from tqdm import tqdm
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
|
from rq_llama import *
|
|
|
from index.datasets import EmbDataset
|
|
|
|
|
|
def if_collided(all_indices_str):
    """Report whether every entry of ``all_indices_str`` is distinct.

    Despite the name, the result is ``True`` when there is *no* collision,
    i.e. when the number of entries equals the number of distinct entries.

    Args:
        all_indices_str: Array-like that supports ``len()`` and ``tolist()``
            (the script passes a NumPy array of stringified codes).

    Returns:
        bool: ``True`` if no value occurs more than once, ``False`` otherwise.
    """
    item_total = len(all_indices_str)
    distinct_total = len(set(all_indices_str.tolist()))
    if item_total != distinct_total:
        return False
    return True
|
|
|
|
|
|
def get_indices_count(all_indices_str):
    """Tally how many times each code occurs.

    Args:
        all_indices_str: Iterable of hashable codes (the script passes an
            array of stringified code lists).

    Returns:
        collections.defaultdict: Maps each distinct code to its number of
        occurrences. Keys appear in first-seen order and, because the
        container is a ``defaultdict(int)``, an unseen key reads as ``0``.
    """
    # iter() guarantees the elements themselves are tallied, one by one,
    # whatever kind of iterable is handed in.
    tally = collections.Counter(iter(all_indices_str))
    indices_count = collections.defaultdict(int)
    indices_count.update(tally)
    return indices_count
|
|
|
|
|
|
def get_collision_item(all_indices_str):
    """Group the positions of entries that share the same code.

    Args:
        all_indices_str: Iterable of hashable codes; an entry's position in
            the iteration is used as its item id.

    Returns:
        list[list[int]]: One list of positions per code that occurs more
        than once. Groups are ordered by the first occurrence of their code
        and positions inside a group are ascending. Codes that occur only
        once are left out.
    """
    positions_by_code = {}
    for position, code in enumerate(all_indices_str):
        positions_by_code.setdefault(code, []).append(position)

    return [group for group in positions_by_code.values() if len(group) > 1]
|
|
|
|
|
|
def parse_args(argv=None):
    """Parse the command-line options of the index-generation script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what a
            plain ``parse_args()`` call has always done.

    Returns:
        argparse.Namespace: Holds the string attributes ``ckpt_path``,
        ``data_path``, ``save_path`` and ``device_map``.
    """
    parser = argparse.ArgumentParser(description="Index")
    parser.add_argument(
        "--ckpt_path",
        type=str,
        default="",
        help="checkpoint path passed to LlamaWithRQ.from_pretrained",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="",
        help="path of the embedding data loaded by EmbDataset",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default="",
        help="path of the JSON file the generated indices are written to",
    )
    # The script converts this value with int(), so it must be an integer
    # device index; a literal such as "cpu" is not accepted there.
    parser.add_argument(
        "--device_map",
        type=str,
        default="1",
        help="integer device index to load the model on (converted with int())",
    )
    return parser.parse_args(argv)
|
|
|
|
|
|
# Parse the command-line options and echo them for the run log.
args = parse_args()
print(args)

# Embedding dataset, read in its stored order (shuffling disabled) so that
# the n-th generated code belongs to the n-th dataset entry.
data = EmbDataset(args.data_path)
data_loader = DataLoader(
    data,
    batch_size=64,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# NOTE(review): presumably the '' key places the whole model on the given
# device index (Hugging Face device_map convention) - confirm.
device_map = {"": int(args.device_map)}
MODEL = LlamaWithRQ.from_pretrained(
    args.ckpt_path,
    device_map=device_map,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
MODEL.eval()

# Shortcuts used by the rest of the script.
device = MODEL.device
rqvae = MODEL.rqvae
prefix = MODEL.prefix
|
|
|
|
|
|
# Encode every embedding into its code sequence, in dataset order.
all_indices = []      # one list of formatted code tokens per item
all_indices_str = []  # str() of the same list, used as a collision key
with torch.no_grad():
    for x in tqdm(data_loader):
        # NOTE(review): the second positional argument (False here, True in
        # the later collision-resolution pass) is presumably a "use
        # Sinkhorn" switch - confirm against rq_llama.
        indices = rqvae.get_indices(x.to(device), False)
        # Flatten to one row of code indices per item.
        indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        for index in indices:
            # prefix[i] is the format template of the i-th code position.
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
            all_indices.append(code)
            all_indices_str.append(str(code))

# dtype=object is deliberate: a plain np.array() of strings gets a
# fixed-width unicode dtype sized to the longest initial value, and the
# later in-place assignments (all_indices[item] = code,
# all_indices_str[item] = str(code)) would then silently truncate any
# longer replacement code. Object arrays hold the Python strings unchanged.
all_indices = np.array(all_indices, dtype=object)
all_indices_str = np.array(all_indices_str, dtype=object)
|
|
|
|
|
|
# Set sk_epsilon to 0.0 on every quantizer layer except the last one ...
for vq in rqvae.rq.vq_layers[:-1]:
    vq.sk_epsilon = 0.0
# ... and give the last layer a non-zero value (0.003) if it was 0.0.
if rqvae.rq.vq_layers[-1].sk_epsilon == 0.0:
    rqvae.rq.vq_layers[-1].sk_epsilon = 0.003

# Re-encode the items whose codes collide. Stops as soon as all codes are
# unique (if_collided returns True in that case) or after 20 rounds.
tt = 0
while tt < 20 and not if_collided(all_indices_str):
    collision_item_groups = get_collision_item(all_indices_str)
    print(len(collision_item_groups))

    with torch.no_grad():
        for collision_items in collision_item_groups:
            # Re-run the quantizer on just the colliding items, this time
            # with the second argument set to True.
            indices = rqvae.get_indices(data[collision_items].to(device), True)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for item, index in zip(collision_items, indices):
                code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
                # Overwrite both representations of the item's code.
                all_indices[item] = code
                all_indices_str[item] = str(code)
    tt += 1
|
|
|
|
|
|
# Summary statistics of the final code assignment.
print("All indices number: ", len(all_indices))
# default=0 keeps max() from raising ValueError when there are no items.
print(
    "Max number of conflicts: ",
    max(get_indices_count(all_indices_str).values(), default=0),
)

tot_item = len(all_indices_str)
tot_indice = len(set(all_indices_str.tolist()))
# Share of items beyond the first holder of each code; reported as 0.0 for
# an empty dataset instead of failing with ZeroDivisionError.
print("Collision Rate", (tot_item - tot_indice) / tot_item if tot_item else 0.0)
|
|
|
|
|
|
# Map each item position to its list of code tokens.
all_indices_dict = {
    item: list(indices) for item, indices in enumerate(all_indices.tolist())
}

# Write the mapping as JSON (the integer keys are serialized as strings).
with open(args.save_path, "w", encoding="utf-8") as f:
    json.dump(all_indices_dict, f)