Benchmark-Dual / index /datasets.py
Junyin's picture
Add files using upload-large-folder tool
936459a verified
raw
history blame contribute delete
947 Bytes
import numpy as np
import torch
import torch.utils.data as data
import pandas as pd
from tqdm import tqdm
class EmbDataset(data.Dataset):
def __init__(self,data_path):
self.data_path = data_path
names = ['emb']
usecols = [1]
tsv_data = pd.read_csv(data_path, sep = '\t',usecols = usecols, names = names, quotechar = None, quoting = 3)
features = tsv_data['emb'].values.tolist()
num_data = len(features)
for i in tqdm(range(num_data)):
features[i] = [float(s) for s in features[i].split(' ')]
self.embeddings = np.array(features, dtype = np.float16)
assert self.embeddings.shape[0] == num_data
self.dim = self.embeddings.shape[-1]
def __getitem__(self, index):
emb = self.embeddings[index]
tensor_emb = torch.tensor(emb, dtype = torch.float16)
return tensor_emb
def __len__(self):
return len(self.embeddings)