Spaces:
Sleeping
Sleeping
| from datasets import load_dataset | |
| import pandas as pd | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.document_loaders import CSVLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain.embeddings import CacheBackedEmbeddings | |
| from langchain.storage import LocalFileStore | |
| from pathlib import Path | |
| from functools import reduce | |
| import os | |
def build_db(openai_api_key):
    """Run the whole pipeline: load the data, index it, then run a test query.

    Args:
        openai_api_key: Key forwarded to ``create_embedder`` for OpenAI access.
    """
    documents = load_imdb_data()
    embedder, embedding_model = create_embedder(openai_api_key)
    build_vectore_store(documents, embedder)
    # Re-read the index from disk; presumably this checks that the saved copy
    # is usable, not only the in-memory one.
    store = load_vector_store(embedder)
    run_test_query(embedding_model, store)
def load_imdb_data():
    """Download the IMDB movies dataset and load it as LangChain documents.

    The ``train`` split of the Hugging Face dataset
    ``ShubhamChoksi/IMDB_Movies`` is written to ``data/imdb.csv``. The CSV is
    previewed with pandas and then loaded with ``CSVLoader``. Diagnostics are
    printed along the way: document count, types, a sample document and the
    total character count.

    Returns:
        list: The documents returned by ``CSVLoader.load()`` (may be empty).
    """
    csv_path = "data/imdb.csv"

    print("Loading IMDB dataset")
    dataset = load_dataset("ShubhamChoksi/IMDB_Movies")
    # Writing the CSV fails if the parent directory is missing (e.g. on a
    # fresh checkout), so create it first.
    os.makedirs(os.path.dirname(csv_path), exist_ok=True)
    dataset["train"].to_csv(csv_path)
    print("")

    print("Creating dataframe")
    movies_dataframe = pd.read_csv(csv_path)
    print(movies_dataframe.head())
    print("")

    print("Loading data from CSV")
    loader = CSVLoader(file_path=csv_path)
    data = loader.load()
    print("Done loading data...")
    # Sanity check that the data is now in a format LangChain recognizes.
    print("Length: " + str(len(data)))
    print("Data list type: " + str(type(data)))
    if data:
        print("Data type: " + str(type(data[0])))
        print(data[0])
    else:
        # Guard the sample print: indexing an empty list would raise IndexError.
        print("No documents were loaded from the CSV.")
    print("")

    print("Calculating total length of data")
    total_length = sum(len(doc.page_content) for doc in data)
    print("Total number of characters in dataset: " + str(total_length))
    print("Total divided by 1,000: " + str(total_length / 1000))
    print("")
    return data
def create_embedder(openai_api_key):
    """Create an OpenAI embedding model plus a disk-cached wrapper around it.

    Args:
        openai_api_key: API key passed to ``OpenAIEmbeddings``.

    Returns:
        tuple: ``(embedder, embedding_model)``.
            ``embedder``: a ``CacheBackedEmbeddings`` backed by a
            ``LocalFileStore`` rooted at ``<cwd>/data/embedding-store``.
            ``embedding_model``: the underlying ``OpenAIEmbeddings`` instance.
    """
    embedding_model = OpenAIEmbeddings(openai_api_key=openai_api_key)
    # LocalFileStore is given a root directory, not a single file, so nothing
    # needs to be pre-created here.
    store_dir = str(Path.cwd() / "data" / "embedding-store")
    store = LocalFileStore(store_dir)
    # Namespace the cache by model name. Otherwise vectors from different
    # embedding models would share cache keys for the same text.
    embedder = CacheBackedEmbeddings.from_bytes_store(
        embedding_model,
        store,
        namespace=embedding_model.model,
    )
    return (embedder, embedding_model)
def build_vectore_store(data, embedder):
    """Load the FAISS index from ``data/week2-movies``, or build and save one.

    If no usable index is on disk, ``data`` is split into ~1000-character
    chunks (100-character overlap). The chunks are embedded and the resulting
    index is saved.

    Args:
        data: LangChain documents to index.
        embedder: Embeddings implementation used to vectorise the chunks.

    Returns:
        The loaded or newly built FAISS vector store. Earlier versions returned
        None, so callers that ignore the result are unaffected.
    """
    store_dir = "data/week2-movies"

    print("Trying to load vector store from file...")
    try:
        # allow_dangerous_deserialization opts in to deserializing pickled
        # data; only point this at a directory this script wrote itself.
        vector_store = FAISS.load_local(store_dir, embedder, allow_dangerous_deserialization=True)
    except Exception as e:
        # Best-effort cache: a missing or unreadable index just means rebuild,
        # but report why instead of swallowing the error silently.
        print("Could not load vector store: " + str(e))
        vector_store = None

    if vector_store is None:
        print("No local vector store found - computing a new one...")
        # Only split when actually building; a cached index needs no chunks.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
        )
        chunked_documents = text_splitter.split_documents(data)
        print("Split data into " + str(len(chunked_documents)) + " chunks")
        # Index the chunks, not the unsplit documents.
        vector_store = FAISS.from_documents(chunked_documents, embedder)
        print("Done computing new vectore store. Saving to local file.")
        os.makedirs(store_dir, exist_ok=True)
        vector_store.save_local(store_dir)
    else:
        print("Found vector store in local file. Using that.")
    print("")
    return vector_store
def load_vector_store(embedder):
    """Read the FAISS index previously saved under ``data/week2-movies``.

    Args:
        embedder: Embeddings object the loaded store uses for queries.

    Returns:
        The FAISS vector store loaded from disk.
    """
    # NOTE: allow_dangerous_deserialization opts in to deserializing pickled
    # data; only use it on files this script produced.
    return FAISS.load_local(
        "data/week2-movies",
        embedder,
        allow_dangerous_deserialization=True,
    )
def run_test_query(embedding_model, vector_store):
    """Smoke-test the vector store with a fixed query and print the matches.

    Args:
        embedding_model: Object whose ``embed_query`` turns the query text
            into a vector.
        vector_store: Store searched via ``similarity_search_by_vector``.
    """
    print("Verifying that we can query the vectore dB...")
    query_vector = embedding_model.embed_query("I have a need. A need for speed.")
    for hit in vector_store.similarity_search_by_vector(query_vector):
        print(str(hit.page_content))
    print("")
if __name__ == "__main__":
    # Accept the original lower-case variable name, falling back to the
    # conventional OPENAI_API_KEY.
    openai_api_key = os.getenv("openai_api_key") or os.getenv("OPENAI_API_KEY")
    if not openai_api_key:
        # Fail fast, before the dataset download, instead of deep in the pipeline.
        raise SystemExit("Set openai_api_key or OPENAI_API_KEY before running this script.")
    build_db(openai_api_key)