import os
import collections
import json
import logging
import argparse
import numpy as np
import pandas as pd
import torch
from time import time
from torch import optim
from tqdm import tqdm
import torch.utils.data as data
from torch.utils.data import DataLoader
from index.models.rqvae import RQVAE
# from rq_llama import *
# from index.datasets import EmbDataset
import random


class NpyDataset(data.Dataset):
    """Dataset that serves rows of an array loaded with ``np.load``."""

    def __init__(self, data_path):
        """Load the array at ``data_path`` and record its last-axis size as ``dim``."""
        self.data_path = data_path
        self.embeddings = np.load(data_path)
        self.dim = self.embeddings.shape[-1]

    def __getitem__(self, index):
        """Return the entry at ``index`` converted to a ``torch.FloatTensor``."""
        return torch.FloatTensor(self.embeddings[index])

    def __len__(self):
        """Return the number of entries along the first axis."""
        return len(self.embeddings)


def if_collided(all_indices_str):
    """Return True when every entry of ``all_indices_str`` is distinct.

    Note that, despite the name, True therefore means "no collision".
    ``all_indices_str`` must provide ``tolist()`` (e.g. a NumPy array).
    """
    distinct = set(all_indices_str.tolist())
    return len(all_indices_str) == len(distinct)


def get_indices_count(all_indices_str):
    """Count how many times each index string occurs.

    Returns a ``defaultdict(int)`` mapping index string -> occurrence count.
    """
    occurrences = collections.defaultdict(int)
    for key in all_indices_str:
        occurrences[key] += 1
    return occurrences


def get_collision_item(all_indices_str):
    """Group item positions that share the same index string.

    Returns a list of position lists, one per index string that occurs more
    than once, ordered by first appearance of the index string.
    """
    positions = collections.defaultdict(list)
    for pos, key in enumerate(all_indices_str):
        positions[key].append(pos)
    return [group for group in positions.values() if len(group) > 1]


def parse_args():
    """Parse the command-line options of the index generation script."""
    parser = argparse.ArgumentParser(description="Index")

    # The four path options share the same type, default and (empty) help.
    path_flags = (
        "--item_model_path",
        "--item_data_path",
        "--user_model_path",
        "--user_data_path",
    )
    for flag in path_flags:
        parser.add_argument(flag, type=str, default="", help="")
    # parser.add_argument("--save_path", type = str, default = "", help = "")
    parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")

    return parser.parse_args()


generate_args = parse_args()
print(generate_args)

device = torch.device(generate_args.device)
# Rows requested per DataLoader batch; the final batch may hold fewer rows.
BATCH_SIZE = 64
# Each row receives NUM_LEVELS random integers drawn from [0, CODEBOOK_SIZE).
NUM_LEVELS = 4
CODEBOOK_SIZE = 256


def _random_indices(batch_len):
    """Return a (batch_len, NUM_LEVELS) integer array of random codes."""
    return np.random.randint(0, CODEBOOK_SIZE, size=(batch_len, NUM_LEVELS), dtype=int)


def _write_json(path, payload):
    """Serialize ``payload`` as JSON to ``path`` (UTF-8, overwriting)."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f)


# generate item index
ckpt = torch.load(
    os.path.join(generate_args.item_model_path, 'best_collision_model.pth'),
    map_location=torch.device('cpu'),
)
args = ckpt['args']
state_dict = ckpt['state_dict']

data = NpyDataset(generate_args.item_data_path)
data_loader = DataLoader(
    data,
    num_workers=args.num_workers,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

# model = RQVAE(
#     in_dim = data.dim,
#     num_emb_list = args.num_emb_list,
#     e_dim = args.e_dim,
#     layers = args.layers,
#     dropout_prob = args.dropout_prob,
#     bn = args.bn,
#     loss_type = args.loss_type,
#     quant_loss_weight = args.quant_loss_weight,
#     kmeans_init = args.kmeans_init,
#     kmeans_iters = args.kmeans_iters,
#     sk_epsilons = args.sk_epsilons,
#     sk_iters = args.sk_iters,
# )
# model.load_state_dict(state_dict)
# model = model.to(device)
# model.eval()
# print(model)

# NOTE(review): every template below is the empty string, so `.format(value)`
# discards the value and every generated code consists of empty strings only.
# Presumably the templates were meant to contain a `{}` placeholder inside some
# token text -- confirm the intended token format and restore it.
prefix = ["", "", "", "", ""]
postfix = ""

# Maps the string form of a code (before the postfix is appended) to the number
# of earlier occurrences of that code; used to disambiguate duplicates.
index_table = {}
all_indices = []
all_indices_str = []

with torch.no_grad():
    for x in tqdm(data_loader):
        # indices = model.get_indices(x.to(device), False)
        # indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        # Size the random codes by the actual batch: the last batch can be
        # smaller than BATCH_SIZE, and each data row must get exactly one code.
        indices = _random_indices(len(x))
        for index in indices:
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]

            # First occurrence gets 0, each repeat increments the counter.
            key = str(code)
            index_table[key] = index_table.get(key, -1) + 1
            code.append(postfix.format(index_table[key]))

            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)
all_indices_str = np.array(all_indices_str)

# `default=0` keeps the script from raising on an empty dataset.
reindex_max = max(index_table.values(), default=0)

print("All indices number: ", len(all_indices))
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default=0))
print('Re-index number:', reindex_max)

# Row position -> list of code strings for that row.
all_indices_dict = {item: list(indices) for item, indices in enumerate(all_indices.tolist())}
reindex_dict = {'reindex': reindex_max}

item_index_path = os.path.join(generate_args.item_model_path, 'indices.random.item.json')
_write_json(item_index_path, all_indices_dict)

item_reindex_path = os.path.join(generate_args.item_model_path, 'reindex.random.item.json')
_write_json(item_reindex_path, reindex_dict)


# generate user index
ckpt = torch.load(
    os.path.join(generate_args.user_model_path, 'best_collision_model.pth'),
    map_location=torch.device('cpu'),
)
args = ckpt['args']
state_dict = ckpt['state_dict']

data = NpyDataset(generate_args.user_data_path)
data_loader = DataLoader(
    data,
    num_workers=args.num_workers,
    batch_size=BATCH_SIZE,
    shuffle=False,
    pin_memory=True,
)

# model = RQVAE(
#     in_dim = data.dim,
#     num_emb_list = args.num_emb_list,
#     e_dim = args.e_dim,
#     layers = args.layers,
#     dropout_prob = args.dropout_prob,
#     bn = args.bn,
#     loss_type = args.loss_type,
#     quant_loss_weight = args.quant_loss_weight,
#     kmeans_init = args.kmeans_init,
#     kmeans_iters = args.kmeans_iters,
#     sk_epsilons = args.sk_epsilons,
#     sk_iters = args.sk_iters,
# )
# model.load_state_dict(state_dict)
# model = model.to(device)
# model.eval()
# print(model)

# NOTE(review): as in the item section, these templates are empty strings, so
# the formatted codes carry no information -- confirm the intended format.
prefix = ['', '', '', '', '']

all_indices = []
all_indices_str = []

with torch.no_grad():
    for x in tqdm(data_loader):
        # indices = rqvae.get_indices(x.to(device), False)
        # indices = indices.view(-1, indices.shape[-1]).cpu().numpy()
        # One random code per row actually present in this batch.
        indices = _random_indices(len(x))
        for index in indices:
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)
all_indices_str = np.array(all_indices_str)

print("All indices number: ", len(all_indices))
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default=0))

# Row position -> list of code strings for that row.
all_indices_dict = {item: list(indices) for item, indices in enumerate(all_indices.tolist())}

json_path = os.path.join(generate_args.user_model_path, 'indices.random.user.json')
_write_json(json_path, all_indices_dict)