|
|
import os
|
|
|
import collections
|
|
|
import json
|
|
|
import logging
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
import torch
|
|
|
from time import time
|
|
|
from torch import optim
|
|
|
from tqdm import tqdm
|
|
|
import torch.utils.data as data
|
|
|
from torch.utils.data import DataLoader
|
|
|
from index.models.rqvae import RQVAE
|
|
|
|
|
|
|
|
|
import random
|
|
|
|
|
|
class NpyDataset(data.Dataset):
    """Map-style dataset over an array stored in a ``.npy`` file.

    The file is read once with ``np.load`` when the dataset is built. Indexing
    returns the selected entry converted to a ``torch.FloatTensor``.

    Attributes:
        data_path: the path that was passed in.
        embeddings: the loaded NumPy array.
        dim: size of the array's last axis (presumably the embedding width).
    """

    def __init__(self, data_path):
        self.data_path = data_path
        loaded = np.load(data_path)
        self.embeddings = loaded
        self.dim = loaded.shape[-1]

    def __getitem__(self, index):
        """Return ``embeddings[index]`` as a float32 torch tensor."""
        return torch.FloatTensor(self.embeddings[index])

    def __len__(self):
        """Return the number of entries along the array's first axis."""
        return len(self.embeddings)
|
|
|
|
|
|
def if_collided(all_indices_str):
    """Report whether every code string in *all_indices_str* is distinct.

    NOTE(review): despite the name, this returns True when there are NO
    collisions, and False as soon as two entries share the same code. It is
    not called in this script; the semantics are kept unchanged for any
    external caller.

    Args:
        all_indices_str: sequence of code strings exposing ``tolist()``
            (e.g. a NumPy array).

    Returns:
        bool: True iff the number of entries equals the number of distinct codes.
    """
    total = len(all_indices_str)
    distinct = set(all_indices_str.tolist())
    return total == len(distinct)
|
|
|
|
|
|
def get_indices_count(all_indices_str):
    """Count how many times each code string occurs.

    Args:
        all_indices_str: iterable of hashable code strings.

    Returns:
        collections.defaultdict(int) mapping each code to its occurrence
        count, in first-occurrence order; missing codes read as 0.
    """
    tally = {}
    for code in all_indices_str:
        tally[code] = tally.get(code, 0) + 1
    return collections.defaultdict(int, tally)
|
|
|
|
|
|
def get_collision_item(all_indices_str):
    """Group positions of entries that share the same code string.

    Args:
        all_indices_str: iterable of hashable code strings.

    Returns:
        list of lists of positions (ints). Each inner list holds every
        position of one code that occurs more than once. Groups appear in
        order of the code's first occurrence.
    """
    positions_by_code = collections.defaultdict(list)
    for position, code in enumerate(all_indices_str):
        positions_by_code[code].append(position)
    return [group for group in positions_by_code.values() if len(group) > 1]
|
|
|
|
|
|
def parse_args():
    """Parse the command line for the random index generator.

    Options:
        --item_model_path / --user_model_path: directories holding
            ``best_collision_model.pth`` and receiving the output JSON.
        --item_data_path / --user_data_path: ``.npy`` files to index.
        --device: torch device string (default ``cuda:0``).

    Returns:
        argparse.Namespace with the parsed values.
    """
    parser = argparse.ArgumentParser(description="Index")
    path_options = ("item_model_path", "item_data_path", "user_model_path", "user_data_path")
    for option in path_options:
        parser.add_argument("--" + option, type=str, default="", help="")
    parser.add_argument("--device", type=str, default="cuda:0", help="gpu or cpu")
    return parser.parse_args()
|
|
|
|
|
|
# Parse the CLI once; this module runs as a script at import time.
generate_args = parse_args()

print(generate_args)

# NOTE(review): `device` is not referenced anywhere below — random codes need no model.
device = torch.device(generate_args.device)

# Load the item checkpoint. Only its saved training args are used below
# (for `num_workers`); no model is built from it in this script.
item_ckpt_path = os.path.join(generate_args.item_model_path, 'best_collision_model.pth')
try:
    # ckpt['args'] is used via attribute access (args.num_workers), presumably
    # an argparse.Namespace. PyTorch >= 2.6 defaults to weights_only=True, which
    # refuses such pickled objects, so full unpickling is requested explicitly.
    # Only load checkpoints you trust: this executes arbitrary pickle data.
    ckpt = torch.load(item_ckpt_path, map_location = torch.device('cpu'), weights_only = False)
except TypeError:
    # torch < 1.13 has no `weights_only` keyword; it always fully unpickles.
    ckpt = torch.load(item_ckpt_path, map_location = torch.device('cpu'))

args = ckpt['args']

# NOTE(review): state_dict is loaded but never used in this random baseline.
state_dict = ckpt['state_dict']

# Use a dedicated name: binding the dataset to `data` would shadow the
# `torch.utils.data` module imported as `data` at the top of the file.
item_dataset = NpyDataset(generate_args.item_data_path)

data_loader = DataLoader(item_dataset, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Token templates for the four random code levels (only a-d are used, since
# four levels are drawn per item; "<e_{}>" is unused here).
prefix = ["<a_{}>","<b_{}>","<c_{}>","<d_{}>","<e_{}>"]

# Disambiguating suffix "<p_k>": k counts earlier items that drew the same
# four-token code, so every full item code is unique.
postfix = "<p_{}>"

# str(four-token code) -> highest suffix number handed out so far.
index_table = {}

all_indices = []

all_indices_str = []

with torch.no_grad():
    for x in tqdm(data_loader):
        # One random 4-level code per item actually in this batch. Using len(x)
        # instead of a hard-coded 64 stops the final, possibly smaller, batch
        # from producing codes for items that do not exist.
        # NOTE(review): np.random is never seeded, so codes differ between runs.
        indices = np.random.randint(0, 256, size = (len(x), 4), dtype = int)
        for index in indices:
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
            key = str(code)
            if key in index_table:
                index_table[key] += 1
            else:
                index_table[key] = 0
            code.append(postfix.format(index_table[key]))
            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)

all_indices_str = np.array(all_indices_str)

print("All indices number: ", len(all_indices))

# Sanity check: with the "<p_k>" suffix every code is unique, so this should
# print 1. default=0 avoids ValueError when the item file is empty.
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default = 0))

# Largest suffix number used, i.e. how many extra "<p_k>" tokens are needed.
reindex = max(index_table.values(), default = 0)

print('Re-index number:', reindex)

# Item position -> list of code tokens (json.dump turns the int keys into strings).
all_indices_dict = {item: list(tokens) for item, tokens in enumerate(all_indices.tolist())}

reindex_dict = {'reindex': reindex}

item_index_path = os.path.join(generate_args.item_model_path, 'indices.random.item.json')

with open(item_index_path, 'w', encoding = 'utf-8') as f:
    json.dump(all_indices_dict, f)

item_reindex_path = os.path.join(generate_args.item_model_path, 'reindex.random.item.json')

with open(item_reindex_path, 'w', encoding = 'utf-8') as f:
    json.dump(reindex_dict, f)
|
|
|
|
|
|
|
|
|
# Load the user checkpoint. Only its saved training args are used below
# (for `num_workers`); no model is built from it in this script.
user_ckpt_path = os.path.join(generate_args.user_model_path, 'best_collision_model.pth')
try:
    # ckpt['args'] is used via attribute access (args.num_workers), presumably
    # an argparse.Namespace. PyTorch >= 2.6 defaults to weights_only=True, which
    # refuses such pickled objects, so full unpickling is requested explicitly.
    # Only load checkpoints you trust: this executes arbitrary pickle data.
    ckpt = torch.load(user_ckpt_path, map_location = torch.device('cpu'), weights_only = False)
except TypeError:
    # torch < 1.13 has no `weights_only` keyword; it always fully unpickles.
    ckpt = torch.load(user_ckpt_path, map_location = torch.device('cpu'))

args = ckpt['args']

# NOTE(review): state_dict is loaded but never used in this random baseline.
state_dict = ckpt['state_dict']

# Use a dedicated name: binding the dataset to `data` would shadow the
# `torch.utils.data` module imported as `data` at the top of the file.
user_dataset = NpyDataset(generate_args.user_data_path)

data_loader = DataLoader(user_dataset, num_workers = args.num_workers, batch_size = 64, shuffle = False, pin_memory = True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Token templates for the user side. Only the first four are used, because
# four random levels are drawn per user.
prefix = ['<z-{}>','<y-{}>','<x-{}>','<w-{}>','<v-{}>']

all_indices = []

all_indices_str = []

with torch.no_grad():
    for x in tqdm(data_loader):
        # One random 4-level code per user actually in this batch. Using len(x)
        # instead of a hard-coded 64 stops the final, possibly smaller, batch
        # from producing codes for users that do not exist.
        # NOTE(review): np.random is never seeded, so codes differ between runs.
        indices = np.random.randint(0, 256, size = (len(x), 4), dtype = int)
        for index in indices:
            code = [prefix[i].format(int(ind)) for i, ind in enumerate(index)]
            # Unlike items, user codes get no "<p_k>" suffix, so duplicates are kept.
            all_indices.append(code)
            all_indices_str.append(str(code))

all_indices = np.array(all_indices)

all_indices_str = np.array(all_indices_str)

print("All indices number: ", len(all_indices))

# Largest number of users sharing one code (1 means all codes are distinct).
# default=0 avoids ValueError when the user file is empty.
print("Max number of conflicts: ", max(get_indices_count(all_indices_str).values(), default = 0))

# User position -> list of code tokens (json.dump turns the int keys into strings).
all_indices_dict = {user: list(tokens) for user, tokens in enumerate(all_indices.tolist())}

json_path = os.path.join(generate_args.user_model_path, 'indices.random.user.json')

with open(json_path, 'w', encoding = 'utf-8') as f:
    json.dump(all_indices_dict, f)