File size: 619 Bytes
2cc98e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import json
import pandas as pd
import os


class EmbeddingsDataLoader:
    """Read a gzip-compressed CSV of embeddings into a pandas DataFrame.

    When no path is supplied, the loader points at
    ``../prefetched/embeddings.csv.gz`` resolved against the directory
    that contains this module.
    """

    def __init__(self, filepath=None):
        """Remember which file to read.

        Args:
            filepath: Location of the gzip-compressed CSV. ``None`` selects
                the default file described in the class docstring.
        """
        if filepath is not None:
            self.filepath = filepath
            return

        # No explicit path given: anchor the default on this module's own
        # directory so it does not depend on the current working directory.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        self.filepath = os.path.join(module_dir, '..', 'prefetched', 'embeddings.csv.gz')

    @staticmethod
    def _decode_vector(raw):
        """Parse one JSON-encoded cell and coerce each element to ``float``."""
        return list(map(float, json.loads(raw)))

    def load(self) -> pd.DataFrame:
        """Load the CSV and decode its ``embedding`` column.

        Each cell of the ``embedding`` column is parsed as JSON and turned
        into a list of floats; all other columns are returned as read by
        ``pd.read_csv``.

        Returns:
            The loaded DataFrame with the decoded ``embedding`` column.
        """
        frame = pd.read_csv(self.filepath, compression='gzip')
        decoded = frame.embedding.apply(self._decode_vector)
        frame.embedding = decoded
        return frame