import numpy as np
import torch
import torch.utils.data as data
class EmbDataset(data.Dataset):
    """Map-style dataset that serves entries of a saved embedding array.

    The array is read with ``np.load`` once, at construction time. Each
    item is the indexed entry of that array converted to a ``float32``
    tensor.

    Attributes:
        data_path: The value passed to the constructor, kept for reference.
        embeddings: The object returned by ``np.load(data_path)``.
        dim: Size of the last axis of ``embeddings``.
    """

    def __init__(self, data_path) -> None:
        """Load the embedding array stored at *data_path*.

        Args:
            data_path: Forwarded unchanged to ``np.load``.
        """
        super().__init__()
        self.data_path = data_path
        self.embeddings = np.load(data_path)
        self.dim = self.embeddings.shape[-1]

    def __getitem__(self, index) -> torch.Tensor:
        """Return ``embeddings[index]`` as a ``float32`` tensor.

        Args:
            index: Anything accepted by indexing ``embeddings``.
        """
        emb = self.embeddings[index]
        # ``as_tensor`` replaces the legacy ``torch.FloatTensor`` type
        # constructor: it yields the same float32 result for array rows and
        # additionally accepts a scalar entry (1-D embedding array), which
        # the legacy constructor does not.
        return torch.as_tensor(emb, dtype=torch.float32)

    def __len__(self) -> int:
        """Return the number of entries along the first axis of ``embeddings``."""
        return len(self.embeddings)