whymath committed on
Commit
40e1b4d
·
1 Parent(s): 7010433

Creating base and RAG chains, adding upload and switch default buttons

Browse files
Files changed (2) hide show
  1. app.py +49 -42
  2. utils.py +23 -78
app.py CHANGED
@@ -2,30 +2,36 @@
2
  import chainlit as cl
3
  from dotenv import load_dotenv
4
  import utils
 
5
 
6
 
7
  load_dotenv()
8
 
9
 
10
- start_msg = "Teach2Learn Virtual Student by Jerry Chiang and Yohan Mathew\n\nYou can choose to upload a PDF, or just start chatting"
11
-
12
- # Create the RAQA chain and store it in the user session
13
- raqa_chain = utils.create_raqa_chain_from_docs()
 
 
 
 
14
 
15
 
16
  @cl.on_chat_start
17
  async def start_chat():
18
- # # Create the RAQA chain and store it in the user session
19
- # raqa_chain = utils.create_raqa_chain_from_docs()
20
- # settings = {
21
- # "chain": raqa_chain
22
- # }
23
- # cl.user_session.set("settings", settings)
24
  print("Chat started")
25
 
26
- # Send a welcome message with an action button
 
 
 
 
 
 
27
  actions = [
28
- cl.Action(name="upload_pdf", value="upload_pdf_value", label="Upload a PDF", description="Upload a PDF")
 
29
  ]
30
  await cl.Message(content=start_msg, actions=actions).send()
31
 
@@ -34,58 +40,59 @@ async def start_chat():
34
  async def main(message: cl.Message):
35
  # Print the message content
36
  user_query = message.content
37
- print('\nuser_query =', user_query)
38
-
39
- # Get the chain from the user session
40
- try:
41
- settings = cl.user_session.get("settings")
42
- raqa_chain_upload = settings["raqa_chain_upload"]
43
- except Exception as e:
44
- print("Error fetching chain from session, defaulting to base chain", e)
45
- raqa_chain_upload = None
46
 
47
  # Generate the response from the chain
48
- if raqa_chain_upload:
49
- print("\nUsing UPLOAD chain to answer query", user_query)
50
- query_response = raqa_chain_upload.invoke({"question" : user_query})
 
 
51
  else:
52
- print("\nUsing DEFAULT chain to answer query", user_query)
53
- query_response = raqa_chain.invoke({"question" : user_query})
54
- query_answer = query_response["response"].content
55
- print('query_answer =', query_answer, '\n')
56
 
57
- # Create and send the message stream
 
58
  msg = cl.Message(content=query_answer)
59
  await msg.send()
60
 
61
 
62
  @cl.action_callback("upload_pdf")
63
  async def upload_pdf_fn(action: cl.Action):
64
- print("\nThe user clicked on an action button!")
65
-
66
- files = None
67
 
68
  # Wait for the user to upload a file
 
69
  while files == None:
70
  files = await cl.AskFileMessage(
71
- content="Processing your file",
72
  accept=["application/pdf"],
73
  max_size_mb=20,
74
  timeout=180,
75
  ).send()
76
-
77
  file_uploaded = files[0]
78
  print("\nUploaded file:", file_uploaded, "\n")
79
 
80
- # Create the RAQA chain and store it in the user session
81
- filepath_uploaded = file_uploaded.path
82
- filename_uploaded = file_uploaded.name
83
- raqa_chain_upload = utils.create_raqa_chain_from_file(filepath_uploaded, filename_uploaded)
 
 
84
 
85
- settings = {
86
- "raqa_chain_upload": raqa_chain_upload
87
- }
 
 
 
 
 
 
 
88
  cl.user_session.set("settings", settings)
89
 
90
- msg = cl.Message(content="Thank you for uploading!")
91
  await msg.send()
 
2
  import chainlit as cl
3
  from dotenv import load_dotenv
4
  import utils
5
+ from langchain_openai import ChatOpenAI
6
 
7
 
8
  load_dotenv()
9
 
10
 
11
+ start_msg = "Hello! I'm Teach2Learn VirtualStudent, a virtual student peer by Jerry Chiang and Yohan Mathew\n\nYou can choose to upload a PDF, or just start chatting\n"
12
+ base_instructions = """
13
+ Assume you are a virtual student being taught by the user. Your goal is to ensure that the user understands the concept they are explaining.
14
+ You should always first let the user know if they are correct or not, and then ask them questions to help them learn by teaching rather than explaining things to them.
15
+ If they ask for feedback, you should provide constructive feedback on the whole conversation instead of asking another question.
16
+ """
17
+ openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
18
+ base_chain = utils.create_base_chain(openai_chat_model, base_instructions)
19
 
20
 
21
  @cl.on_chat_start
22
  async def start_chat():
 
 
 
 
 
 
23
  print("Chat started")
24
 
25
+ # Set the user session settings
26
+ settings = {
27
+ "rag_chain_available": False
28
+ }
29
+ cl.user_session.set("settings", settings)
30
+
31
+ # Send a welcome message with action buttons
32
  actions = [
33
+ cl.Action(name="upload_pdf", value="upload_pdf_value", label="Upload a PDF", description="Upload a PDF"),
34
+ cl.Action(name="switch_default", value="switch_default_value", label="Switch back to default mode", description="Switch back to default mode")
35
  ]
36
  await cl.Message(content=start_msg, actions=actions).send()
37
 
 
40
async def main(message: cl.Message):
    """Answer an incoming user message with the RAG chain if one is loaded, else the base chain."""
    user_query = message.content
    session_settings = cl.user_session.get("settings")

    # Pick the active chain; RAG responses nest the model output under "response",
    # while the base chain returns the model message directly.
    if session_settings["rag_chain_available"]:
        print("\nUsing RAG chain to answer query", user_query)
        response = session_settings["rag_chain"].invoke({"question" : user_query})
        query_answer = response["response"].content
    else:
        print("\nUsing base chain to answer query", user_query)
        query_answer = base_chain.invoke({"question" : user_query}).content

    # Echo the answer to the console and send it back to the chat.
    print('query_answer =', query_answer, '\n')
    await cl.Message(content=query_answer).send()
60
 
61
 
62
@cl.action_callback("upload_pdf")
async def upload_pdf_fn(action: cl.Action):
    """Action handler: prompt for a PDF upload, build a RAG chain over it, and enable RAG mode.

    Stores the new chain in the user session and flips ``rag_chain_available``
    so subsequent messages are routed to the RAG chain.
    """
    print("\nRunning PDF upload and RAG chain creation")

    # Wait for the user to upload a file (keep asking until one arrives).
    files = None
    while files is None:  # fixed: identity check for None instead of `== None`
        files = await cl.AskFileMessage(
            content="Processing your file...",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file_uploaded = files[0]
    print("\nUploaded file:", file_uploaded, "\n")

    # Create the RAG chain and store it in the user session.
    rag_chain = utils.create_rag_chain_from_file(openai_chat_model, base_instructions, file_uploaded.path, file_uploaded.name)
    settings = cl.user_session.get("settings")
    settings["rag_chain"] = rag_chain
    settings["rag_chain_available"] = True
    cl.user_session.set("settings", settings)

    msg = cl.Message(content="Ready to discuss the uploaded PDF file!")
    await msg.send()
87
+
88
+
89
@cl.action_callback("switch_default")
async def switch_default_fn(action: cl.Action):
    """Action handler: disable the RAG chain so the base chain answers again."""
    print("\nSwitching back to default base chain")

    session_settings = cl.user_session.get("settings")
    session_settings["rag_chain_available"] = False
    cl.user_session.set("settings", session_settings)

    reply = cl.Message(content="Okay, I'm back to answering general questions. What would you like to try teaching me next?")
    await reply.send()
utils.py CHANGED
@@ -4,9 +4,7 @@ from langchain.text_splitter import RecursiveCharacterTextSplitter
4
  from langchain_openai.embeddings import OpenAIEmbeddings
5
  from langchain_community.vectorstores import Qdrant
6
  from langchain_core.prompts import ChatPromptTemplate
7
- from langchain_openai import ChatOpenAI
8
  from operator import itemgetter
9
- # from langchain.schema.output_parser import StrOutputParser
10
  from langchain.schema.runnable import RunnablePassthrough
11
 
12
 
@@ -23,77 +21,28 @@ def chunk_documents(docs, tiktoken_len):
23
  chunk_overlap = 0,
24
  length_function = tiktoken_len,
25
  )
26
-
27
  split_chunks = text_splitter.split_documents(docs)
28
-
29
  print('len(split_chunks) =', len(split_chunks))
30
-
31
  return split_chunks
32
 
33
 
34
- def create_raqa_chain_from_docs():
35
- # Load the documents from a PDF file using PyMuPDFLoader
36
- # docs = PyMuPDFLoader("data/c7318154-f6ae-4866-89fa-f0c589f2ee3d.pdf").load()
37
- docs = PyMuPDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001326801/c7318154-f6ae-4866-89fa-f0c589f2ee3d.pdf").load()
38
-
39
- # Print the number of loaded documents
40
- print("Loaded", len(docs), "documents")
41
-
42
- # Print the first document
43
- print(docs[0])
44
-
45
- # Split the documents into chunks based on their length
46
- split_chunks = chunk_documents(docs, tiktoken_len)
47
-
48
- # Create an instance of the OpenAIEmbeddings model for text embeddings
49
- embedding_model = OpenAIEmbeddings(model="text-embedding-3-small")
50
-
51
- # Create a Qdrant vector store from the split chunks
52
- qdrant_vectorstore = Qdrant.from_documents(
53
- split_chunks,
54
- embedding_model,
55
- location=":memory:",
56
- collection_name="Meta 10-k Filings",
57
- )
58
-
59
- # Create a retriever from the Qdrant vector store
60
- qdrant_retriever = qdrant_vectorstore.as_retriever()
61
-
62
- # Define the RAG prompt template
63
- RAG_PROMPT = """
64
- CONTEXT:
65
- {context}
66
-
67
- QUERY:
68
- {question}
69
-
70
- Use the provided context to answer the provided user query. Only use the provided context to answer the query. If you do not know the answer, respond with "I don't know".
71
- """
72
-
73
- # Create a ChatPromptTemplate instance from the RAG prompt template
74
- rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
75
-
76
- # Create an instance of the ChatOpenAI model
77
- openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
78
-
79
- # Define the retrieval augmented QA chain
80
- retrieval_augmented_qa_chain = (
81
- {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
82
- | RunnablePassthrough.assign(context=itemgetter("context"))
83
- | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
84
- )
85
- print("Created retrieval augmented QA chain from default PDF file")
86
-
87
- return retrieval_augmented_qa_chain
88
 
89
 
90
- def create_raqa_chain_from_file(filepath_uploaded, filename_uploaded):
91
 
92
- # # Load the documents from a PDF file using PyMuPDFLoader
93
- # docs = PyMuPDFLoader("https://d18rn0p25nwr6d.cloudfront.net/CIK-0001326801/c7318154-f6ae-4866-89fa-f0c589f2ee3d.pdf").load()
94
- docs = PyMuPDFLoader(filepath_uploaded).load()
95
  print("Loaded", len(docs), "documents")
96
- print(docs[0])
97
 
98
  # Create a Qdrant vector store from the split chunks and embedding model, and obtain its retriever
99
  split_chunks = chunk_documents(docs, tiktoken_len)
@@ -102,35 +51,31 @@ def create_raqa_chain_from_file(filepath_uploaded, filename_uploaded):
102
  split_chunks,
103
  embedding_model,
104
  location=":memory:",
105
- collection_name="LoadedPDF",
106
  )
107
  qdrant_retriever = qdrant_vectorstore.as_retriever()
 
108
 
109
  # Define the RAG prompt template
110
- # RAG_PROMPT = """
111
- # Assume you are a virtual student being taught by the user. You can ask clarifying questions to better understand the user's explanation. Your goal is to ensure that the user understands the concept they are explaining. You can also ask questions to help the user elaborate on their explanation. You can ask questions like "Can you explain that in simpler terms?" or "Can you provide an example?".
112
-
113
- # USER MESSAGE:
114
- # {question}
115
- # """
116
  RAG_PROMPT = """
117
- CONTEXT:
118
- {context}
119
 
120
  QUERY:
121
  {question}
122
 
123
- Use the provided context to answer the provided user query. Only use the provided context to answer the query. If you do not know the answer, respond with "I don't know".
 
124
  """
 
 
125
  rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
126
 
127
  # Create the retrieval augmented QA chain using the Qdrant retriever, RAG prompt, and OpenAI chat model
128
- openai_chat_model = ChatOpenAI(model="gpt-3.5-turbo")
129
- retrieval_augmented_qa_chain = (
130
  {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
131
  | RunnablePassthrough.assign(context=itemgetter("context"))
132
  | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
133
  )
134
- print("Created retrieval augmented QA chain from uploaded PDF file =", filename_uploaded, "\n")
135
 
136
- return retrieval_augmented_qa_chain
 
4
  from langchain_openai.embeddings import OpenAIEmbeddings
5
  from langchain_community.vectorstores import Qdrant
6
  from langchain_core.prompts import ChatPromptTemplate
 
7
  from operator import itemgetter
 
8
  from langchain.schema.runnable import RunnablePassthrough
9
 
10
 
 
21
  chunk_overlap = 0,
22
  length_function = tiktoken_len,
23
  )
 
24
  split_chunks = text_splitter.split_documents(docs)
 
25
  print('len(split_chunks) =', len(split_chunks))
 
26
  return split_chunks
27
 
28
 
29
def create_base_chain(openai_chat_model, base_instructions):
    """Build the non-RAG chat chain: system instructions + user question piped into the model."""
    message_templates = [
        ("system", base_instructions),
        ("human", "{question}"),
    ]
    prompt = ChatPromptTemplate.from_messages(message_templates)
    chain = prompt | openai_chat_model
    print("Created base chain\n")
    return chain
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
+ def create_rag_chain_from_file(openai_chat_model, base_instructions, file_path, file_name):
41
 
42
+ # Load the documents from a PDF file using PyMuPDFLoader
43
+ docs = PyMuPDFLoader(file_path).load()
 
44
  print("Loaded", len(docs), "documents")
45
+ print("First document:\n", docs[0], "\n")
46
 
47
  # Create a Qdrant vector store from the split chunks and embedding model, and obtain its retriever
48
  split_chunks = chunk_documents(docs, tiktoken_len)
 
51
  split_chunks,
52
  embedding_model,
53
  location=":memory:",
54
+ collection_name=file_name,
55
  )
56
  qdrant_retriever = qdrant_vectorstore.as_retriever()
57
+ print("Created Qdrant vector store from uploaded PDF file =", file_name)
58
 
59
  # Define the RAG prompt template
 
 
 
 
 
 
60
  RAG_PROMPT = """
61
+ Use the provided context while replying to the user query. Only use the provided context to answer the query.
 
62
 
63
  QUERY:
64
  {question}
65
 
66
+ CONTEXT:
67
+ {context}
68
  """
69
+ RAG_PROMPT = base_instructions + RAG_PROMPT
70
+ print("RAG prompt template =", RAG_PROMPT)
71
  rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
72
 
73
  # Create the retrieval augmented QA chain using the Qdrant retriever, RAG prompt, and OpenAI chat model
74
+ rag_chain = (
 
75
  {"context": itemgetter("question") | qdrant_retriever, "question": itemgetter("question")}
76
  | RunnablePassthrough.assign(context=itemgetter("context"))
77
  | {"response": rag_prompt | openai_chat_model, "context": itemgetter("context")}
78
  )
79
+ print("Created RAG chain from uploaded PDF file =", file_name, "\n")
80
 
81
+ return rag_chain