resume-analyzer / Code+Folder /src /embedding_model.py
shekkari21's picture
updated scripts
1a1a2a1
from langchain_openai import OpenAIEmbeddings
import pickle
from constants import EMBEDDING_MODEL_NAME, OUTPUT_PATH
class EmbeddingModel:
    """
    Create, save, and load text embeddings produced by an OpenAI embedding model.

    Embeddings are handled as plain dictionaries mapping an identifier to the
    vector returned by the embedding backend. Persistence uses ``pickle`` and
    files are located by concatenating ``OUTPUT_PATH`` with the file name.
    """

    def __init__(self):
        """
        Build the underlying ``OpenAIEmbeddings`` client for ``EMBEDDING_MODEL_NAME``.

        No API key is passed explicitly; the client is expected to pick it up
        from the ``OPENAI_API_KEY`` environment variable.
        """
        self.embedding_model = OpenAIEmbeddings(model=EMBEDDING_MODEL_NAME)

    def get_embeddings(self, data_dict: dict) -> dict:
        """
        Generate embeddings for every text in ``data_dict``.

        Args:
            data_dict (dict): Maps identifiers to the texts to be embedded.

        Returns:
            dict: Same keys as ``data_dict`` (in the same order), each mapped
            to the embedding returned for its text. An empty (or ``None``)
            input yields an empty dict without contacting the backend.

        Raises:
            IndexError: If the backend returns fewer embeddings than texts.
        """
        # Nothing to embed: skip the backend call entirely.
        if not data_dict:
            return {}

        texts = list(data_dict.values())
        embeddings = self.embedding_model.embed_documents(texts)

        # Keys and values of an unmodified dict iterate in the same order, so
        # position i in `embeddings` belongs to the i-th key. Indexing (rather
        # than zip) keeps a short backend response loud instead of silently
        # dropping keys.
        return {key: embeddings[i] for i, key in enumerate(data_dict)}

    @staticmethod
    def save_embeddings(embedding: dict, file_name: str) -> None:
        """
        Pickle ``embedding`` to ``OUTPUT_PATH + file_name``.

        Args:
            embedding (dict): The embeddings to be saved.
            file_name (str): Name of the target file; it is appended to
                ``OUTPUT_PATH`` verbatim, so ``OUTPUT_PATH`` must already end
                with a path separator if it denotes a directory.
        """
        with open(OUTPUT_PATH + file_name, 'wb') as handle:
            pickle.dump(embedding, handle, protocol=pickle.HIGHEST_PROTOCOL)

    @staticmethod
    def read_embeddings(file_name: str) -> dict:
        """
        Load previously pickled embeddings from ``OUTPUT_PATH + file_name``.

        Args:
            file_name (str): Name of the file to read; it is appended to
                ``OUTPUT_PATH`` verbatim.

        Returns:
            dict: The object stored in the file (a dict when it was written by
            ``save_embeddings``).
        """
        with open(OUTPUT_PATH + file_name, 'rb') as handle:
            # SECURITY: pickle.load can execute arbitrary code while
            # deserializing. Only read files this application wrote itself;
            # never point this at untrusted or user-supplied files.
            return pickle.load(handle)