Spaces:
Sleeping
Sleeping
| import re | |
| from langchain_openai import OpenAIEmbeddings | |
| from langchain_openai import ChatOpenAI | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema import StrOutputParser | |
| from langchain_community.document_loaders import PyMuPDFLoader | |
| from langchain_community.vectorstores import Qdrant | |
| from langchain_core.runnables import RunnablePassthrough, RunnableParallel | |
| from langchain_core.documents import Document | |
| from operator import itemgetter | |
| import os | |
| from dotenv import load_dotenv | |
| import chainlit as cl | |
# Pull environment variables from a .env file (if one exists) before any
# client that depends on them is constructed.
load_dotenv()

# Source PDFs that make up the knowledge base.
AI_FRAMEWORK_PDF_URL = "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf"
AI_BLUEPRINT_PDF_URL = "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf"

# Each loader call fetches and parses its PDF at import time.
ai_framework_document = PyMuPDFLoader(file_path=AI_FRAMEWORK_PDF_URL).load()
ai_blueprint_document = PyMuPDFLoader(file_path=AI_BLUEPRINT_PDF_URL).load()
def metadata_generator(document, name):
    """Split loaded documents into overlapping chunks and label each chunk.

    Args:
        document: The documents to split (here, the output of a PDF
            loader's ``load()``).
        name: Label stored under ``metadata["source"]`` on every chunk,
            overwriting any value already present.

    Returns:
        The chunk documents produced by the splitter, each tagged with
        ``name`` as its source.
    """
    splitter = RecursiveCharacterTextSplitter(
        separators=["\n\n", "\n", ".", "!", "?"],
        chunk_overlap=100,
        chunk_size=500,
    )
    chunks = splitter.split_documents(document)
    # Tag in place so retrieved chunks can be traced back to their document.
    for chunk in chunks:
        chunk.metadata.update(source=name)
    return chunks
# Chunk both source documents and label the chunks with a readable source name.
recursive_framework_document = metadata_generator(ai_framework_document, "AI Framework")
recursive_blueprint_document = metadata_generator(ai_blueprint_document, "AI Blueprint")

# One flat corpus: framework chunks first, then blueprint chunks.
combined_documents = [*recursive_framework_document, *recursive_blueprint_document]

# Embed the corpus into an in-memory Qdrant collection; the index is rebuilt
# on every process start because nothing is persisted.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
vectorstore = Qdrant.from_documents(
    combined_documents,
    embeddings,
    location=":memory:",
    collection_name="ai_policy",
)

# Retriever with the vector store's default search settings.
alt_retriever = vectorstore.as_retriever()
## Generation LLM
llm = ChatOpenAI(model="gpt-4o-mini")

# Prompt template with two placeholders, {context} and {question}, that the
# chain fills in before the LLM call. Built line by line so the exact text,
# including the trailing newline, is explicit.
RAG_PROMPT = (
    "You are an AI Policy Expert.\n"
    "Given a provided context and question, you must answer the question based only on context.\n"
    "Think through your answer carefully and step by step.\n"
    "Context: {context}\n"
    "Question: {question}\n"
)

rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
# INVOKE CHAIN WITH: {"question": "<<SOME USER QUESTION>>"}

# Stage 1: fan the input out into the two keys the prompt needs.
#   "context"  - the question text piped into alt_retriever
#   "question" - the question text, passed along unchanged
_retrieve_stage = {
    "context": itemgetter("question") | alt_retriever,
    "question": itemgetter("question"),
}

# Stage 2: forward the mapping as-is; "context" is reassigned to its own value.
_forward_stage = RunnablePassthrough.assign(context=itemgetter("context"))

# Stage 3: build the final output.
#   "response" - prompt formatted with context + question, then sent to the LLM
#   "context"  - the retrieved context, returned alongside the answer
_respond_stage = {
    "response": rag_prompt | llm,
    "context": itemgetter("context"),
}

retrieval_augmented_qa_chain = _retrieve_stage | _forward_stage | _respond_stage
# Example: retrieval_augmented_qa_chain.invoke({"question": "What is the AI framework all about?"})
@cl.on_message
async def handle_message(message):
    """Answer one incoming chat message with the RAG chain.

    Registered through ``cl.on_message`` so Chainlit calls it for every
    message the user sends.

    Args:
        message: The incoming Chainlit message; its ``content`` is used as
            the question.

    Any exception raised while answering or replying is printed and reported
    to the user in the chat instead of being propagated.
    """
    try:
        # Use the async entry point: the synchronous ``invoke`` would block
        # the event loop for the whole retrieval + LLM round trip.
        result = await retrieval_augmented_qa_chain.ainvoke(
            {"question": message.content}
        )
        # "response" holds the LLM's message object; reply with its text.
        await cl.Message(content=result["response"].content).send()
    except Exception as e:
        # Print first so the failure is recorded even if the reply below
        # cannot be delivered.
        print(f"Error occurred: {e}")
        await cl.Message(content=f"An error occurred: {str(e)}").send()
# Run the ChainLit server
if __name__ == "__main__":
    # Chainlit apps are normally started with `chainlit run <file>`. This
    # guard makes `python <file>` start the same server through the CLI's
    # programmatic entry point.
    # NOTE(review): run_chainlit presumably loads this file again as the app
    # module, so the module-level setup (PDF download, embedding) would run
    # twice when started this way - confirm; `chainlit run` avoids that.
    try:
        from chainlit.cli import run_chainlit

        run_chainlit(__file__)
    except Exception as e:
        print(f"Server error occurred: {e}")