Spaces:
Sleeping
Sleeping
File size: 2,041 Bytes
from langchain_openai import OpenAIEmbeddings
import pickle
from constants import EMBEDDING_MODEL_NAME, OUTPUT_PATH
class EmbeddingModel:
    """Create, save, and load text embeddings produced by an OpenAI embedding model.

    The embedding backend is a ``langchain_openai.OpenAIEmbeddings`` instance
    configured with ``EMBEDDING_MODEL_NAME``. Persistence uses ``pickle`` files
    located under ``OUTPUT_PATH``.
    """

    def __init__(self):
        """Initialize the embedding backend for ``EMBEDDING_MODEL_NAME``.

        No API key is passed explicitly; the client is expected to pick it up
        from the ``OPENAI_API_KEY`` environment variable.
        """
        self.embedding_model = OpenAIEmbeddings(model=EMBEDDING_MODEL_NAME)

    def get_embeddings(self, data_dict: dict) -> dict:
        """Generate embeddings for every text in ``data_dict``.

        Args:
            data_dict: Mapping of identifier -> text to embed.

        Returns:
            A new dict with the same keys, in the same order, whose values are
            the embeddings returned by the backend for the matching texts.

        Raises:
            ValueError: If the backend returns a different number of embeddings
                than the number of texts submitted.
        """
        # Nothing to embed: skip the backend call entirely.
        if not data_dict:
            return {}

        keys = list(data_dict)
        texts = list(data_dict.values())
        embeddings = self.embedding_model.embed_documents(texts)

        # zip() would silently truncate on a length mismatch, so check first.
        if len(embeddings) != len(keys):
            raise ValueError(
                f"Expected {len(keys)} embeddings but received {len(embeddings)}"
            )
        return dict(zip(keys, embeddings))

    @staticmethod
    def save_embeddings(embedding: dict, file_name: str) -> None:
        """Pickle ``embedding`` to ``OUTPUT_PATH + file_name``.

        Args:
            embedding: The embeddings to be saved.
            file_name: Name of the target file. It is appended to
                ``OUTPUT_PATH`` by plain string concatenation, so
                ``OUTPUT_PATH`` must already end with a path separator.
        """
        with open(OUTPUT_PATH + file_name, 'wb') as handle:
            pickle.dump(embedding, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def read_embeddings(file_name: str) -> dict:
        """Load previously pickled embeddings from ``OUTPUT_PATH + file_name``.

        Args:
            file_name: Name of the file to read. It is appended to
                ``OUTPUT_PATH`` by plain string concatenation.

        Returns:
            The object stored in the file (the embeddings dict written by
            ``save_embeddings``).
        """
        # SECURITY: pickle.load can execute arbitrary code while unpickling.
        # Only read files this application wrote itself; never load pickles
        # obtained from an untrusted source.
        with open(OUTPUT_PATH + file_name, 'rb') as handle:
            return pickle.load(handle)
|