# Path: trygithubactions/src/genai/utils/load_embeddings.py
# Last commit: 34b6a10 by subashpoudel — "fixed naming convention"
# Size: 1.32 kB
import numpy as np
import ast
import faiss
import pandas as pd
from datasets import load_dataset
def load_caption_index():
dataset = load_dataset("DvorakInnovationAI/rt-genai-dataset-v1", revision="openai-embeddings")
df = dataset["train"]
df= df.to_pandas()
df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)
embeddings = np.vstack(df['embeddings'].values).astype('float32')
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
return df, embeddings, index
def load_imdb_ideas_index():
dataset = load_dataset("DvorakInnovationAI/rt-genai-imdb-ideas-v1", revision='openai-embeddings')
df = dataset['train']
df= df.to_pandas()
df['embeddings'] = df['embeddings'].apply(lambda x: ast.literal_eval(x) if isinstance(x,str) else x)
embeddings = np.vstack(df['embeddings'].values).astype('float32')
faiss.normalize_L2(embeddings)
index = faiss.IndexFlatIP(embeddings.shape[1])
index.add(embeddings)
return df , embeddings , index
# Build both similarity indexes eagerly, at import time, and expose each
# (DataFrame, normalized embedding matrix, FAISS index) triple as module globals.
print('Loading Embeddings...........')
caption_df, caption_embeddings, caption_index = load_caption_index()
ideas_df, ideas_embeddings, ideas_index = load_imdb_ideas_index()
print('Embeddings Loaded.................')