| | import logging |
| | import os |
| | import yaml |
| |
|
| | from modules.embedding_model_loader import EmbeddingModelLoader |
| | from langchain.vectorstores import FAISS |
| | from modules.data_loader import DataLoader |
| | from modules.constants import * |
| | from modules.helpers import * |
| |
|
| |
|
class VectorDB:
    """Build, persist and reload a vector store described by a config mapping.

    All settings are read from ``config["embedding_options"]``. The keys this
    class uses are ``db_option``, ``data_path``, ``url_file_path``,
    ``expand_urls``, ``db_path`` and ``model``.
    """

    def __init__(self, config, logger=None):
        """Store the configuration and set up logging.

        Args:
            config: Mapping with an ``"embedding_options"`` sub-mapping.
            logger: Optional logger to use. When ``None``, the module logger
                is configured to write INFO records to the console and to
                ``vector_db.log``.
        """
        self.config = config
        self.db_option = config["embedding_options"]["db_option"]
        self.document_names = None
        # NOTE(review): WebpageCrawler is not imported by name in this file;
        # presumably it comes from one of the wildcard imports - confirm.
        self.webpage_crawler = WebpageCrawler()

        if logger is None:
            self.logger = logging.getLogger(__name__)
            self.logger.setLevel(logging.INFO)

            # logging.getLogger returns the same object for the same name, so
            # handlers are attached only once. Otherwise every additional
            # VectorDB instance would duplicate each log line and reopen the
            # log file in "w" mode.
            if not self.logger.handlers:
                formatter = logging.Formatter(
                    "%(asctime)s - %(levelname)s - %(message)s"
                )

                console_handler = logging.StreamHandler()
                console_handler.setLevel(logging.INFO)
                console_handler.setFormatter(formatter)
                self.logger.addHandler(console_handler)

                log_file_path = "vector_db.log"
                file_handler = logging.FileHandler(log_file_path, mode="w")
                file_handler.setLevel(logging.INFO)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)
        else:
            self.logger = logger

        self.logger.info("VectorDB instance instantiated")

    def _database_dir(self):
        """Return the directory used by both save_database and load_database.

        The path is ``<db_path>/db_<db_option>_<model>``, built from the
        ``embedding_options`` section of the configuration.
        """
        options = self.config["embedding_options"]
        return os.path.join(
            options["db_path"],
            "db_" + options["db_option"] + "_" + options["model"],
        )

    def load_files(self):
        """Collect the local file paths and the URLs to ingest.

        Returns:
            A ``(files, urls)`` tuple. ``files`` holds every entry of
            ``data_path`` joined with that directory. ``urls`` holds the URLs
            read from ``url_file_path``; when ``expand_urls`` is truthy each
            URL is replaced by the pages the crawler returns for it.
        """
        options = self.config["embedding_options"]
        data_path = options["data_path"]
        files = [os.path.join(data_path, name) for name in os.listdir(data_path)]

        urls = get_urls_from_file(options["url_file_path"])
        if options["expand_urls"]:
            all_urls = []
            for url in urls:
                base_url = get_base_url(url)
                all_urls.extend(self.webpage_crawler.get_all_pages(url, base_url))
            urls = all_urls
        return files, urls

    def create_embedding_model(self):
        """Create the embedding model loader and load ``self.embedding_model``."""
        self.logger.info("Creating embedding function")
        self.embedding_model_loader = EmbeddingModelLoader(self.config)
        self.embedding_model = self.embedding_model_loader.load_embedding_model()

    def initialize_database(self, document_chunks: list, document_names: list):
        """Build ``self.vector_db`` from the given document chunks.

        Args:
            document_chunks: Documents passed to ``FAISS.from_documents``.
            document_names: Names accompanying the chunks; stored on
                ``self.document_names``.

        Raises:
            ValueError: If ``db_option`` is not ``"FAISS"``. Without this the
                method would report success while leaving ``self.vector_db``
                unset, and the failure would only surface later as an
                AttributeError in save_database.
        """
        self.logger.info("Initializing vector_db")
        self.logger.info("\tUsing {} as db_option".format(self.db_option))
        if self.db_option == "FAISS":
            self.vector_db = FAISS.from_documents(
                documents=document_chunks, embedding=self.embedding_model
            )
        else:
            raise ValueError(
                "Unsupported db_option: {!r}".format(self.db_option)
            )
        self.document_names = document_names
        self.logger.info("Completed initializing vector_db")

    def create_database(self):
        """Load the data, create the embedding model and build the vector store."""
        data_loader = DataLoader(self.config)
        self.logger.info("Loading data")
        files, urls = self.load_files()
        document_chunks, document_names = data_loader.get_chunks(files, urls)
        self.logger.info("Completed loading data")

        self.create_embedding_model()
        self.initialize_database(document_chunks, document_names)

    def save_database(self):
        """Persist ``self.vector_db`` to the configured database directory."""
        self.vector_db.save_local(self._database_dir())
        self.logger.info("Saved database")

    def load_database(self):
        """Load the stored FAISS index from the configured database directory.

        The embedding model is created first because ``FAISS.load_local``
        takes it as an argument.

        Returns:
            The loaded vector store, also kept on ``self.vector_db``.
        """
        self.create_embedding_model()
        self.vector_db = FAISS.load_local(
            self._database_dir(),
            self.embedding_model,
        )
        self.logger.info("Loaded database")
        return self.vector_db
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: read the YAML configuration, echo it, then build
    # the vector store and write it to disk.
    with open("config.yml", "r") as config_file:
        config = yaml.safe_load(config_file)
    print(config)

    vector_db = VectorDB(config)
    vector_db.create_database()
    vector_db.save_database()
| |
|