File size: 947 Bytes
811e03d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import torch
import torch.utils.data as data
import pandas as pd
from tqdm import tqdm

class EmbDataset(data.Dataset):
    """In-memory dataset of float16 embedding vectors read from a TSV file.

    The file is read as tab-separated text without a header row and with
    quoting disabled. Only the column at index 1 (the second column) is
    used; each of its fields must hold one embedding written as
    whitespace-separated decimal numbers. All rows must have the same
    number of components.

    Attributes:
        data_path: The path (or buffer) the embeddings were loaded from.
        embeddings: ``numpy`` array of dtype ``float16`` with one row per
            line of the input file.
        dim: Size of the last axis of ``embeddings``.
    """

    def __init__(self, data_path):
        """Load and parse every embedding found in *data_path*.

        Args:
            data_path: Anything ``pandas.read_csv`` accepts as its first
                argument.

        Raises:
            ValueError: If a row has a missing, empty or malformed
                embedding field, or if rows differ in length.
        """
        self.data_path = data_path
        tsv_data = pd.read_csv(
            data_path,
            sep='\t',
            usecols=[1],
            names=['emb'],
            # Keep fields as text: without this a one-dimensional embedding
            # such as "0.5" is inferred as a float and cannot be split.
            dtype=str,
            quotechar=None,
            quoting=3,  # csv.QUOTE_NONE
        )
        raw_rows = tsv_data['emb'].values.tolist()
        num_data = len(raw_rows)

        features = [
            self._parse_embedding(raw, row)
            for row, raw in enumerate(tqdm(raw_rows))
        ]

        # Report ragged input with the offending row instead of relying on
        # numpy's generic "inhomogeneous shape" error.
        if features:
            expected = len(features[0])
            for row, vector in enumerate(features):
                if len(vector) != expected:
                    raise ValueError(
                        f"row {row}: embedding has {len(vector)} values, "
                        f"expected {expected}"
                    )

        self.embeddings = np.array(features, dtype=np.float16)
        if self.embeddings.shape[0] != num_data:
            raise ValueError(
                f"parsed {self.embeddings.shape[0]} embeddings "
                f"from {num_data} rows"
            )
        self.dim = self.embeddings.shape[-1]

    @staticmethod
    def _parse_embedding(text, row=None):
        """Convert one whitespace-separated embedding field to a float list.

        Args:
            text: The raw field value. Must be a non-blank string.
            row: Optional zero-based row number, used only in error messages.

        Returns:
            A list of Python floats, one per component.

        Raises:
            ValueError: If *text* is not a string (e.g. a missing field
                that pandas reported as NaN), contains no numbers, or
                contains a token that is not a valid float.
        """
        where = "" if row is None else f"row {row}: "
        if not isinstance(text, str):
            raise ValueError(f"{where}missing embedding field ({text!r})")
        # split() without an argument tolerates repeated and trailing
        # whitespace, which split(' ') turns into empty tokens.
        tokens = text.split()
        if not tokens:
            raise ValueError(f"{where}empty embedding field")
        try:
            return [float(token) for token in tokens]
        except ValueError as err:
            raise ValueError(
                f"{where}malformed embedding value in {text[:80]!r}"
            ) from err

    def __getitem__(self, index) -> torch.Tensor:
        """Return the embedding(s) at *index* as a new float16 tensor."""
        emb = self.embeddings[index]
        return torch.tensor(emb, dtype=torch.float16)

    def __len__(self) -> int:
        """Return the number of embeddings held in memory."""
        return len(self.embeddings)