import chainlit as cl
import os

from classes.app_state import AppState
from classes.model_run_state import ModelRunState
from dotenv import load_dotenv
from langchain.schema.runnable import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_openai.embeddings import OpenAIEmbeddings
from langchain.embeddings import HuggingFaceEmbeddings
from operator import itemgetter
from utilities.doc_utilities import get_documents
from utilities.templates import get_qa_prompt
from utilities.vector_utilities import create_vector_store
document_urls = [
    "https://www.whitehouse.gov/wp-content/uploads/2022/10/Blueprint-for-an-AI-Bill-of-Rights.pdf",
    "https://nvlpubs.nist.gov/nistpubs/ai/NIST.AI.600-1.pdf",
]
# Load environment variables from .env file
load_dotenv()

# Get the OpenAI API key from environment variables
openai_api_key = os.getenv("OPENAI_API_KEY")
# Set up our state and read the documents
app_state = AppState()
app_state.set_debug(False)
app_state.set_document_urls(document_urls)
get_documents(app_state)
# Set up this model run
chainlit_state = ModelRunState()
chainlit_state.name = "Chainlit"

# Chat model used to answer questions
chainlit_state.qa_model_name = "gpt-4o-mini"
chainlit_state.qa_model = ChatOpenAI(model=chainlit_state.qa_model_name, openai_api_key=openai_api_key)

# Fine-tuned Arctic embedding model hosted on Hugging Face
hf_username = "rchrdgwr"
hf_repo_name = "finetuned-arctic-model"
finetuned_model_name = f"{hf_username}/{hf_repo_name}"
chainlit_state.embedding_model_name = finetuned_model_name
chainlit_state.embedding_model = HuggingFaceEmbeddings(model_name=chainlit_state.embedding_model_name)
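
# A quick (commented-out) sanity check: HuggingFaceEmbeddings exposes LangChain's
# standard Embeddings interface. The query text here is purely illustrative:
#   vector = chainlit_state.embedding_model.embed_query("What is the AI Bill of Rights?")
#   print(len(vector))  # dimensionality of the fine-tuned Arctic embedding space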
# Chunking parameters used when splitting the documents
chainlit_state.chunk_size = 1000
chainlit_state.chunk_overlap = 100

create_vector_store(app_state, chainlit_state)
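
# For orientation only: a rough sketch of what create_vector_store() is assumed
# to do (the real implementation lives in utilities/vector_utilities.py). The
# splitter, store class, and app_state.documents attribute are illustrative
# assumptions, not the project's confirmed choices:
#
#   splitter = RecursiveCharacterTextSplitter(
#       chunk_size=chainlit_state.chunk_size,
#       chunk_overlap=chainlit_state.chunk_overlap,
#   )
#   chunks = splitter.split_documents(app_state.documents)
#   store = FAISS.from_documents(chunks, chainlit_state.embedding_model)
#   chainlit_state.retriever = store.as_retriever()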
chat_prompt = get_qa_prompt()

# Create the retrieval-augmented QA chain: the question is routed to the
# retriever for context, then both are passed to the prompt and chat model
retrieval_augmented_qa_chain = (
    {"context": itemgetter("question") | chainlit_state.retriever, "question": itemgetter("question")}
    | RunnablePassthrough.assign(context=itemgetter("context"))
    | {"response": chat_prompt | chainlit_state.qa_model, "context": itemgetter("context")}
)
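
# A minimal sketch of invoking the chain directly (outside Chainlit), assuming
# create_vector_store() has populated chainlit_state.retriever; the question
# text is purely illustrative:
#
#   result = retrieval_augmented_qa_chain.invoke({"question": "What is the AI Bill of Rights?"})
#   result["response"].content   # the model's answer
#   result["context"]            # the retrieved Document chunks used as context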
opening_content = """
Welcome!

I am AI Mentor - your guide to understanding the evolving AI industry.
My goal is to help you learn how to think about building ethical and useful applications.

I can answer your questions on AI based on the following two documents:
- Blueprint for an AI Bill of Rights, from the White House Office of Science and Technology Policy
- Artificial Intelligence Risk Management Framework: Generative Artificial Intelligence Profile, from NIST

What would you like to learn about AI today?
"""
@cl.on_chat_start
async def on_chat_start():
    # Send the welcome message when a new chat session starts
    await cl.Message(content=opening_content).send()
@cl.on_message
async def main(message: cl.Message):
    MAX_PREVIEW_LENGTH = 100  # max characters shown per context preview (see below)

    # Run the RAG chain on the user's question
    response = retrieval_augmented_qa_chain.invoke({"question": message.content})
    answer_content = response["response"].content

    # Stream the answer back to the user in 50-character chunks
    msg = cl.Message(content="")
    for i in range(0, len(answer_content), 50):
        chunk = answer_content[i : i + 50]
        await msg.stream_token(chunk)
    await msg.send()  # finalize the streamed message
    # Report which document chunks were retrieved as context
    context_documents = response["context"]

    # Build a compact source list, e.g. "Sources: BOR-3 RMF-12"
    chunk_string = "Sources:"
    for doc in context_documents:
        document_title = doc.metadata.get("source", "Unknown Document")
        chunk_number = doc.metadata.get("chunk_number", "Unknown Chunk")
        # Label each chunk by source document: BOR (Blueprint for an AI
        # Bill of Rights) or RMF (AI Risk Management Framework)
        if "Blueprint" in document_title:
            doc_string = "BOR"
        else:
            doc_string = "RMF"
        chunk_string = f"{chunk_string} {doc_string}-{chunk_number}"
    await cl.Message(content=chunk_string).send()
    # Optionally, show a truncated preview of each retrieved chunk:
    # for doc in context_documents:
    #     document_title = doc.metadata.get("source", "Unknown Document")
    #     chunk_number = doc.metadata.get("chunk_number", "Unknown Chunk")
    #     document_context = doc.page_content.strip()
    #     truncated_context = document_context[:MAX_PREVIEW_LENGTH] + ("..." if len(document_context) > MAX_PREVIEW_LENGTH else "")
    #     await cl.Message(
    #         content=f"**{document_title} (Chunk: {chunk_number})**",
    #         elements=[cl.Text(content=truncated_context, display="inline")],
    #     ).send()