Spaces:
Sleeping
Sleeping
| import os | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_mongodb.vectorstores import MongoDBAtlasVectorSearch | |
| from langchain_core.prompts import PromptTemplate | |
| from langchain.chains import RetrievalQA | |
| from langchain_openai import ChatOpenAI | |
| import logging | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Name of the Atlas vector-search index; passed as index_name to
# MongoDBAtlasVectorSearch below.
INDEX_NAME = "vector_index"
# Database whose collections hold the embedded documents; the caller
# supplies only the collection name.
DATABASE_NAME = "scraped_data_db"
def mongo_rag_tool(query: str, collection_name: str) -> tuple[str, str]:
    """Answer *query* with RAG over documents stored in MongoDB Atlas.

    Embeds the query, runs an MMR vector search against the given
    collection, and feeds the retrieved documents as context to a chat
    model that produces the answer.

    Args:
        query: The natural-language question to answer.
        collection_name: Name of the collection inside ``DATABASE_NAME``
            holding the embedded documents.

    Returns:
        A ``(answer, formatted_sources)`` pair. On failure both elements
        contain the error message instead of raising (best-effort
        boundary preserved from the original design).
    """
    try:
        openai_api_key = os.getenv("OPENAI_API_KEY")
        embeddings = OpenAIEmbeddings(
            openai_api_key=openai_api_key,
            disallowed_special=(),
            model="text-embedding-3-small",
        )
        uri = os.getenv("MONGO_CONNECTION_STRING")
        logging.info("Creating the mongo vector search object")
        vector_search = MongoDBAtlasVectorSearch.from_connection_string(
            uri,
            f"{DATABASE_NAME}.{collection_name}",
            embeddings,
            index_name=INDEX_NAME,
        )
        logging.info("Retrieving the documents and answering the query")
        # Project only the fields downstream code reads from each hit.
        post_filter = [{"$project": {"_id": 0, "text": 1, "source": 1, "score": 1, "embedding": 1}}]
        # MMR: fetch the top 100 candidates, re-rank for diversity, keep 10.
        qa_retriever = vector_search.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 10, "fetch_k": 100, "post_filter_pipeline": post_filter},
        )
        prompt_template = """Use the following pieces of context to answer the question at the end.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
If you know the answer give a comprehensive, detailed and insightful answer.
{context}
Question: {question}
"""
        PROMPT = PromptTemplate(
            template=prompt_template, input_variables=["context", "question"]
        )
        qa = RetrievalQA.from_chain_type(
            llm=ChatOpenAI(api_key=openai_api_key, model="gpt-4o", temperature=0.2),
            chain_type="stuff",
            retriever=qa_retriever,
            return_source_documents=True,
            chain_type_kwargs={"prompt": PROMPT},
        )
        docs = qa.invoke({"query": query})
        if docs:
            logging.info("Saving the retrieved documents")
            sources = docs["source_documents"]
            source_list = [
                {"content": doc.page_content, "source": doc.metadata.get("source", "")}
                for doc in sources
            ]
            formatted_sources = "\n".join(
                f"Content: {src['content']}\nSource: {src['source']}\n"
                for src in source_list
            )
            return docs["result"], formatted_sources
        # FIX: previously fell through and implicitly returned None, which
        # breaks callers that unpack two values.
        return "No answer was produced for the query.", ""
    except Exception as e:
        logging.error(f"An error occurred: {str(e)}")
        # FIX: the second returned string was missing its f-prefix, so the
        # literal text "{str(e)}" was returned instead of the error message.
        error_message = f"An error occurred: {str(e)}"
        return error_message, error_message
| #mongo_rag_tool("What do people think about caterpillar vision link fleet management app") |