"""Generate discrete item indices from a trained ``LlamaWithRQ`` checkpoint.

The script encodes every embedding in the dataset with the model's RQ-VAE,
formats each code with the model's per-level prefix templates, retries the
items whose formatted codes collide, prints collision statistics and dumps
the ``item position -> list of code tokens`` mapping as JSON.
"""

import os
import collections
import json
import logging
import argparse
from time import time

import numpy as np
import pandas as pd
import torch
from torch import optim
from tqdm import tqdm
from torch.utils.data import DataLoader

from rq_llama import *
from index.datasets import EmbDataset

# Upper bound on the number of collision-resolution passes over the dataset.
MAX_RESOLVE_ROUNDS = 20


def if_collided(all_indices_str):
    """Return True when every index string is unique.

    NOTE: despite the name, a True result means there are *no* collisions;
    the resolution loop below uses it as its stop condition.
    """
    codes = list(all_indices_str)
    return len(codes) == len(set(codes))


def get_indices_count(all_indices_str):
    """Return a mapping from each index string to its number of occurrences."""
    return collections.Counter(all_indices_str)


def get_collision_item(all_indices_str):
    """Group item positions that share the same index string.

    Returns a list of groups (lists of positions); only groups holding more
    than one position are included, in order of first appearance.
    """
    index2id = collections.defaultdict(list)
    for position, index in enumerate(all_indices_str):
        index2id[index].append(position)
    return [group for group in index2id.values() if len(group) > 1]


def parse_args():
    """Parse the command-line arguments of the script."""
    parser = argparse.ArgumentParser(description="Index")
    parser.add_argument("--ckpt_path", type=str, default="", help="")
    parser.add_argument("--data_path", type=str, default="", help="")
    parser.add_argument("--save_path", type=str, default="", help="")
    # NOTE(review): the help text mentions cpu, but the value is passed
    # through int() below, so only an integer device id is accepted.
    parser.add_argument("--device_map", type=str, default="1", help="gpu or cpu")
    return parser.parse_args()


def _format_code(index, prefix):
    """Format one row of raw code ids with the per-level prefix templates."""
    return [prefix[level].format(int(ind)) for level, ind in enumerate(index)]


args = parse_args()
print(args)

data = EmbDataset(args.data_path)
data_loader = DataLoader(
    data,
    num_workers=4,
    batch_size=64,
    shuffle=False,
    pin_memory=True,
)

device_map = {'': int(args.device_map)}
MODEL = LlamaWithRQ.from_pretrained(
    args.ckpt_path,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map=device_map,
)
MODEL.eval()

device = MODEL.device
rqvae = MODEL.rqvae
prefix = MODEL.prefix

all_indices = []
all_indices_str = []

# First pass: encode every item. The second positional argument of
# get_indices is False here and True in the retry pass below
# (presumably a "use Sinkhorn" switch -- TODO confirm against rqvae).
with torch.no_grad():
    for x in tqdm(data_loader):
        indices = rqvae.get_indices(x.to(device), False)
        indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        for index in indices:
            code = _format_code(index, prefix)
            all_indices.append(code)
            all_indices_str.append(str(code))

# dtype=object is deliberate: a default NumPy string array is fixed-width,
# so re-assigning a longer code during collision resolution would be
# silently truncated.
all_indices = np.array(all_indices, dtype=object)
all_indices_str = np.array(all_indices_str, dtype=object)

# Zero sk_epsilon on every VQ layer but the last, and make sure the last
# layer has a non-zero value before the retry passes.
for vq in rqvae.rq.vq_layers[:-1]:
    vq.sk_epsilon = 0.0

if rqvae.rq.vq_layers[-1].sk_epsilon == 0.0:
    rqvae.rq.vq_layers[-1].sk_epsilon = 0.003

# Re-encode colliding items until all codes are unique or the round budget
# is exhausted.
for _ in range(MAX_RESOLVE_ROUNDS):
    if if_collided(all_indices_str):
        break

    collision_item_groups = get_collision_item(all_indices_str)
    print(len(collision_item_groups))

    with torch.no_grad():
        for collision_items in collision_item_groups:
            indices = rqvae.get_indices(data[collision_items].to(device), True)
            indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
            for item, index in zip(collision_items, indices):
                code = _format_code(index, prefix)
                all_indices[item] = code
                all_indices_str[item] = str(code)

print("All indices number: ",len(all_indices))
# default=0 keeps an empty dataset from raising on max() of no values.
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default=0))

tot_item = len(all_indices_str)
tot_indice = len(set(all_indices_str.tolist()))
# Guard the ratio so an empty dataset reports 0.0 instead of dividing by zero.
print("Collision Rate",(tot_item - tot_indice) / tot_item if tot_item else 0.0)

all_indices_dict = {}
for item, indices in enumerate(all_indices.tolist()):
    all_indices_dict[item] = list(indices)

with open(args.save_path, 'w', encoding='utf-8') as f:
    json.dump(all_indices_dict, f)