import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from tqdm import tqdm
class EmbDataset(data.Dataset):
    """In-memory dataset of embedding vectors read from a TSV file.

    Only the second tab-separated column (index 1) of the file is read. Each
    cell in that column is expected to hold one embedding written as
    whitespace-separated numbers. All vectors are kept together in a single
    ``float16`` NumPy array, so values are rounded to half precision.
    """

    def __init__(self, data_path):
        """Load and parse every embedding found in ``data_path``.

        Args:
            data_path: Location of the TSV file; passed unchanged to
                ``pandas.read_csv``.

        Raises:
            ValueError: If a token cannot be parsed as a float, or if the
                rows do not all contain the same number of values.
        """
        self.data_path = data_path

        # ``names`` is given without ``header``, so the first line is read as
        # data. quoting=3 is csv.QUOTE_NONE: quote characters stay literal.
        # dtype=str keeps cells textual even when every cell holds a single
        # number, which pandas would otherwise infer as a numeric column.
        tsv_data = pd.read_csv(
            data_path,
            sep="\t",
            usecols=[1],
            names=["emb"],
            dtype=str,
            quotechar=None,
            quoting=3,
        )
        raw_rows = tsv_data["emb"].values.tolist()

        # Convert row by row to compact float16 vectors instead of holding a
        # list of lists of Python floats until the very end.
        vectors = []
        for row_no, text in enumerate(tqdm(raw_rows), start=1):
            vector = self._parse_row(text)
            if vectors and vector.shape != vectors[0].shape:
                raise ValueError(
                    f"{data_path}: row {row_no} has {vector.size} values, "
                    f"expected {vectors[0].size}"
                )
            vectors.append(vector)

        # Shape is (num_rows, dim); with no rows this is a 1-D array of
        # shape (0,), and ``dim`` is then 0.
        self.embeddings = np.array(vectors, dtype=np.float16)
        # Size of the last axis of ``embeddings``.
        self.dim = self.embeddings.shape[-1]

    @staticmethod
    def _parse_row(text):
        """Turn one whitespace-separated string of numbers into a float16 vector."""
        # split() with no argument tolerates repeated, leading and trailing
        # whitespace; split(' ') would yield empty tokens that float() rejects.
        return np.array([float(token) for token in text.split()], dtype=np.float16)

    def __getitem__(self, index):
        """Return the embedding(s) selected by ``index`` as a new float16 tensor.

        ``index`` is applied directly to the underlying NumPy array, and
        ``torch.tensor`` copies the selected data.
        """
        return torch.tensor(self.embeddings[index], dtype=torch.float16)

    def __len__(self):
        """Return the number of stored embeddings."""
        return len(self.embeddings)