# Source file (947 bytes) at revision 811e03d.
import numpy as np
import torch
import torch.utils.data as data
import pandas as pd
from tqdm import tqdm
class EmbDataset(data.Dataset):
    """Map-style dataset of embedding vectors read from a tab-separated file.

    Only the second column (index 1) of each row is read. That cell must hold
    the vector components as whitespace-separated numbers. All vectors are
    stored together in one float16 NumPy array and served as float16 tensors.

    Attributes:
        data_path: Path the embeddings were loaded from.
        embeddings: float16 array with one row per input line.
        dim: Number of components per embedding (last axis of ``embeddings``).
    """

    def __init__(self, data_path):
        """Load and parse every embedding in ``data_path``.

        Raises:
            ValueError: if a row has a missing/empty embedding, a component
                that is not a number, or a different length than other rows.
        """
        self.data_path = data_path
        names = ['emb']
        usecols = [1]
        # quoting=3 is csv.QUOTE_NONE: quote characters are kept as plain text.
        # dtype=str keeps each cell as raw text so pandas' type inference cannot
        # turn a one-component embedding into a float before we parse it.
        tsv_data = pd.read_csv(data_path, sep='\t', usecols=usecols, names=names,
                               quotechar=None, quoting=3, dtype=str)
        raw = tsv_data['emb'].tolist()
        features = [self._parse_emb(text, row) for row, text in enumerate(tqdm(raw))]
        # _stack guarantees one row per parsed feature (uniform lengths).
        self.embeddings = self._stack(features)
        self.dim = self.embeddings.shape[-1]

    @staticmethod
    def _parse_emb(text, row):
        """Parse one whitespace-separated embedding cell into a list of floats.

        ``row`` is only used to make error messages point at the bad line.
        """
        # A missing cell arrives from pandas as NaN (a float), not a string.
        if not isinstance(text, str):
            raise ValueError(f"row {row}: missing embedding value {text!r}")
        # split() with no argument tolerates repeated and leading/trailing
        # whitespace; split(' ') would yield '' tokens that float() rejects.
        values = [float(tok) for tok in text.split()]
        if not values:
            raise ValueError(f"row {row}: empty embedding")
        return values

    @staticmethod
    def _stack(features):
        """Stack parsed vectors into a float16 array, rejecting ragged input."""
        lengths = {len(vec) for vec in features}
        if len(lengths) > 1:
            raise ValueError(
                f"embeddings have inconsistent dimensions: {sorted(lengths)}")
        return np.array(features, dtype=np.float16)

    def __getitem__(self, index):
        """Return the embedding at ``index`` as a new ``torch.float16`` tensor."""
        emb = self.embeddings[index]
        tensor_emb = torch.tensor(emb, dtype=torch.float16)
        return tensor_emb

    def __len__(self):
        """Return the number of embeddings (rows) in the dataset."""
        return len(self.embeddings)