Jorge Londoño committed on
Commit
bcaa3c2
·
1 Parent(s): 8ea0e72

Updated project files

Browse files
Files changed (3) hide show
  1. app.py +16 -55
  2. assistant.py +60 -0
  3. correctiveRag.py +437 -0
app.py CHANGED
import gradio as gr
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat completion for `message`, yielding the growing answer.

    Args:
        message: the new user message.
        history: prior (user, assistant) turn pairs.
        system_message: system prompt prepended to the conversation.
        max_tokens, temperature, top_p: sampling parameters forwarded to the model.

    Yields:
        The partial response text, re-emitted after every streamed token.
    """
    messages = [{"role": "system", "content": system_message}]

    # Replay prior turns; skip empty halves of a pair.
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": message})

    response = ""

    # Bug fixes vs. the original: the loop variable no longer shadows the
    # `message` parameter, and a None delta content (sent on the final
    # streamed chunk by some backends) no longer raises TypeError on `+=`.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content or ""
        response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()
import uuid

import gradio as gr

from assistant import voice_input, chat_response


if __name__ == "__main__":
    with gr.Blocks() as demo:
        # Bug fix: a plain str(uuid.uuid4()) here is evaluated once at process
        # start, so every browser session would share the same LangGraph
        # thread_id (and therefore the same conversation memory). Passing a
        # callable makes Gradio evaluate it per session on load.
        state = gr.State(value=lambda: {'session_id': str(uuid.uuid4())})
        chatbot = gr.Chatbot(type='messages')
        with gr.Row():
            txt = gr.Textbox(show_label=False, placeholder="Type your message here", container=False)
            voice = gr.Audio(type="filepath", sources=["microphone"], label="Or speak your message")

        # Typed input: clears the textbox and appends to the chat history.
        txt_msg = txt.submit(chat_response, inputs=[txt, chatbot, state], outputs=[txt, chatbot])

        # Voice input: transcribe the recording, then run the same chat flow.
        voice_msg = voice.stop_recording(
            voice_input, inputs=[voice, chatbot, state], outputs=[txt, chatbot]
        )

    demo.launch()
 
 
assistant.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging
logging.basicConfig(level=logging.WARNING)

import os
from langchain.schema import AIMessage, HumanMessage, SystemMessage

from dotenv import load_dotenv
load_dotenv(verbose=True)
# Fail fast at import time if required environment variables are missing.
assert os.getenv('GROQ_MODEL') is not None
assert os.getenv('GROQ_WHISPER_MODEL') is not None

import gradio as gr
from gradio import ChatMessage

# Compiled corrective-RAG LangGraph application.
from correctiveRag import app

logger = logging.getLogger(__name__)  # Child logger for this module
logger.setLevel(logging.INFO)

# Groq - Audio transcription
from groq import Groq
transcription_client = Groq()


# Fixed typo in the prompt: "consice" -> "concise".
system_message = "You are a helpful assistant who provides concise answers to questions."
def chat_response(message: str, history: list[dict], state: dict):
    """Run one chat turn through the corrective-RAG graph.

    Args:
        message: the user's text; may be None or empty when nothing was typed.
        history: Gradio chat history in 'messages' format.
        state: per-session dict holding 'session_id', used as the LangGraph thread id.

    Returns:
        ("", history): clears the textbox and returns the updated history.
    """
    logger.debug(f"session_id = {state['session_id']}")
    config = {"configurable": {"thread_id": state['session_id']}}

    # Robustness fix: skip empty submissions entirely. Invoking the graph with
    # a None/empty message would send a contentless HumanMessage downstream.
    if message:
        history.append(ChatMessage(role='user', content=message))
        response = app.invoke({"messages": [HumanMessage(content=message)]}, config)
        # response['generation'] is joined in case the graph returns chunks.
        answer = ''.join(response['generation'])
        history.append(ChatMessage(role='assistant', content=answer))
    return "", history
def transcribe_audio(filename):
    """Transcribe an audio file with Groq's Whisper endpoint.

    Args:
        filename: path to the recorded audio file.

    Returns:
        The transcribed text (Spanish is assumed via `language="es"`).
    """
    # Idiom fix: stray debugging print replaced with lazy logger call.
    logger.debug("transcribing %s with model %s", filename, os.getenv('GROQ_WHISPER_MODEL'))
    with open(filename, "rb") as audio_file:
        transcription = transcription_client.audio.transcriptions.create(
            file=(filename, audio_file.read()),
            model=os.getenv('GROQ_WHISPER_MODEL'),  # Required model to use for transcription
            prompt="Preguntas sobre estructuras de datos y algoritmos.",  # Optional
            response_format="json",  # Optional
            language="es",  # Optional
            temperature=0.0  # Optional
        )
    return transcription.text
def voice_input(audio, history, state):
    """Transcribe a recorded clip and feed the text through chat_response.

    Args:
        audio: filepath of the recording (may be None for an aborted recording).
        history: Gradio chat history in 'messages' format.
        state: per-session state dict.

    Returns:
        ("", history): same contract as chat_response.
    """
    # Robustness fix: stop_recording can fire with no file; avoid open(None).
    if audio is None:
        return "", history
    transcription = transcribe_audio(audio)
    return chat_response(transcription, history, state)
+
58
+
59
+
60
+
correctiveRag.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import logging

import os
from pydantic import BaseModel, Field
from typing import Literal, List, Any, Annotated

from typing_extensions import TypedDict
from langchain.schema import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.messages import HumanMessage, AIMessage, AnyMessage
from langgraph.graph import END, StateGraph, MessagesState, START
from langgraph.graph.message import add_messages
from huggingface_hub import InferenceClient

from dotenv import load_dotenv
load_dotenv(verbose=True)
# Fail fast at import time if required environment variables are missing.
assert os.getenv("PINECONE_API_KEY") is not None
assert os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL") is not None


logger = logging.getLogger(__name__)  # Child logger for this module
logger.setLevel(logging.INFO)
# Bug fix: the PINECONE_API_KEY line was plain text inside the f-string, so the
# literal source code was logged instead of the (truncated) key prefix.
logger.info(f"""correctiveRag.py:Config:
    GROQ_MODEL = {os.getenv('GROQ_MODEL')}
    HUGGINGFACEHUB_EMBEDDINGS_MODEL = {os.getenv('HUGGINGFACEHUB_EMBEDDINGS_MODEL')}
    PINECONE_API_KEY = {os.getenv('PINECONE_API_KEY')[:5]}
""")

# Prepare the LLM
# from langchain_groq import ChatGroq
# assert os.getenv('GROQ_MODEL') is not None, "GROQ_MODEL not set"
# llm = ChatGroq(model_name=os.getenv('GROQ_MODEL'), temperature=0, verbose=True)

# from langchain_openai import ChatOpenAI
# assert os.getenv('OPENAI_MODEL_NAME') is not None, "GROQ_MODEL not set"
# llm = ChatOpenAI(model=os.getenv("OPENAI_MODEL_NAME"), temperature=0.1, verbose=True)

# Huggingface
# NOTE(review): InferenceClient is not a LangChain Runnable; the
# `.with_structured_output()` and `|`-chain usages later in this module appear
# to require a LangChain chat model (e.g. the commented-out ChatGroq) — confirm.
llm = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


# Prepare the retriever
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
index_name, namespace = 'courses', 'dsa'
# Simple RAG (kept for reference):
# embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL) # os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL")
# docsearch = PineconeVectorStore.from_existing_index(embedding=embeddings, index_name=index_name, namespace=namespace)
# retriever = docsearch.as_retriever(search_type="mmr", search_kwargs={ 'k': 5 })
# Large-Small RAG
def larger_from_nearby(vectorstore, doc: Document, range: int) -> Document:
    """
    Given a document, find the "parent" document as a range of chunks around the central chunk.

    Args:
        vectorstore: Pinecone vector store holding the chunked documents.
        doc: the central chunk; its metadata must contain 'document' and 'chunk'.
        range: number of neighbouring chunks to include on each side.
               (Name kept for interface compatibility; it shadows the builtin,
               so it is aliased to `window` below.)

    Returns:
        A Document concatenating the page_content of chunks
        [chunk-range, chunk+range] from the same source document.
    """
    window = range  # avoid using the shadowed builtin name further down
    filter0 = {"document": doc.metadata['document']}
    filter1 = {"chunk": {"$gte": doc.metadata['chunk'] - window}}
    filter2 = {"chunk": {"$lte": doc.metadata['chunk'] + window}}
    and_filter = {"$and": [filter0, filter1, filter2]}
    range_docs = vectorstore.similarity_search(query='', k=2 * window + 1, filter=and_filter)
    # Bug fix: the original loop reused the name `doc`, so the merged Document
    # carried the metadata of the *last* neighbouring chunk instead of the
    # central chunk it was built around.
    content = ''.join(chunk_doc.page_content for chunk_doc in range_docs)
    return Document(page_content=content, metadata=doc.metadata)
def larger_retriever(vectorstore, query: str, topK: int):
    """Retrieve topK chunks for `query`, each expanded to its surrounding chunks."""
    RANGE = 2  # expand each hit to chunks -RANGE..+RANGE around it
    logger.info(f'larger_retriever: with RANGE={RANGE}')
    hits = vectorstore.similarity_search(query, k=topK)
    larger_documents = [larger_from_nearby(vectorstore, hit, RANGE) for hit in hits]
    logger.info(f'larger_retriever: Found {len(larger_documents)} documents.')
    return larger_documents
# Large-small retriever wiring
embeddings = HuggingFaceEmbeddings(model_name=os.getenv("HUGGINGFACEHUB_EMBEDDINGS_MODEL"))
vectorstore = PineconeVectorStore.from_existing_index(embedding=embeddings, index_name=index_name, namespace=namespace)
# docs = larger_retriever(vectorstore, query, 5)
retriever = lambda query: larger_retriever(vectorstore, query, 5)  # TODO make topK configurable


# Classify question
class ClassifyQuestion(BaseModel):
    """Binary score to decide if need to retrieve documents from the vectorstore about data structures and algorithms.
    The binary_score is "yes" to indicate that document retrieval is needed, otherwise is "no"."""
    binary_score: str = Field(description="If the question is about data structures and algorithms answer `yes`, otherwise answer `no`")
    # justification: str = Field(description="Explained reasoning for giving the yes/no score")

# LLM with function call
structured_llm_grader = llm.with_structured_output(ClassifyQuestion)

# Prompt (grammar fixed: "If the question are specific..." /
# "it is a question as a general question" read as broken English to the model)
system = """You are an expert at classifying user questions.
If the question is specifically about data structures and algorithms, then answer `yes` to indicate that document retrieval is needed.
Otherwise, it is a general question; answer `no`.
"""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Question: {question}"),
    ]
)
retriever_grader = grade_prompt | structured_llm_grader
# Retrieval grader: scores each retrieved chunk against the user question.
class GradeDocuments(BaseModel):
    """Binary score for relevance check on retrieved documents."""
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

# Bind the schema to the LLM so it answers with a structured yes/no.
structured_llm_grader = llm.with_structured_output(GradeDocuments)

system = """You are a grader assessing relevance of a retrieved document to a user question.
If the document contains keyword(s) or semantic meaning related to the user question, grade it as relevant.
It does not need to be a stringent test. The goal is to filter out erroneous retrievals.
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved document: \n\n {document} \n\n User question: {question}"),
    ]
)
retrieval_grader = grade_prompt | structured_llm_grader
# Create the RAG chain
from langchain import hub
from langchain_core.output_parsers import StrOutputParser

# Local prompt used instead of hub.pull("rlm/rag-prompt").
template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Please keep the answer concise and to the point.

Context: {context}

Question: {question}

Answer:
"""
prompt_template = PromptTemplate.from_template(template=template)
rag_chain = prompt_template | llm | StrOutputParser()
# Question rewriter: rephrases the user question for better vectorstore recall.
# Prompt typo fixed: "You a question re-writer" -> "You are a question re-writer".
system = """You are a question re-writer that converts an input question to a better version that is optimized
for vectorstore retrieval. Look at the input and try to reason about the underlying semantic intent / meaning.
Return only the re-written question. Do not return anything else.
"""
re_write_prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Here is the initial question: \n\n {question} \n Formulate an improved question."),
    ]
)
question_rewriter = re_write_prompt | llm | StrOutputParser()
# Web search tool
from langchain_community.tools.tavily_search import TavilySearchResults
# Bug fix: the result-count parameter of TavilySearchResults is `max_results`;
# the original passed `k=3`, which is not a declared field of the tool and so
# the intended limit was never applied.
web_search_tool = TavilySearchResults(max_results=3)
# Define the workflow Graph

class GraphState(TypedDict):
    """
    State shared by every node of the corrective-RAG graph.

    Attributes:
        messages: conversation history, accumulated via add_messages
        generation: LLM generation
        web_search: whether to add a web search ("Yes"/"No")
        documents: list of retrieved documents
        question: the last user question
    """
    messages: Annotated[list[AnyMessage], add_messages]
    generation: str
    web_search: str
    documents: List[str]
    question: str
def chatbot(state: GraphState):
    """Answer directly from the conversation history, with no retrieved context."""
    logger.info("---GENERATE (no context)---")
    logger.info(state)
    generation = (llm | StrOutputParser()).invoke(state["messages"])
    logger.info(generation)
    return {"messages": [AIMessage(content=generation)], "generation": generation}
def retrieve(state):
    """
    Retrieve documents for the latest user question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    logger.info("---RETRIEVE---")
    question = state['messages'][-1].content  # Last Human message
    logger.info(f'question: {question}')
    # documents = retriever.invoke(question)
    retrieved = retriever(question)  # Large-small retriever
    logger.info([(d.metadata['id'], d.page_content[:20]) for d in retrieved])
    return {"documents": retrieved, "question": question}
def generate_with_context(state):
    """
    Generate an answer grounded in the retrieved documents.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    logger.debug("---GENERATE WITH CONTEXT---")
    logger.debug(f'state: {state}')
    question, documents = state["question"], state["documents"]
    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    logger.debug(generation)
    return {"documents": documents, "question": question, "generation": generation}
def web_search(state):
    """
    Web search based on the re-phrased question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with appended web results
    """
    logger.debug("---WEB SEARCH---")
    question = state["question"]
    documents = state["documents"]
    # Web search
    logger.debug(f'question: {question}')
    docs = web_search_tool.invoke({"query": question})  # Returns str if error
    logger.debug(f'type(docs) = {type(docs)}')
    logger.debug(docs)
    web_results = Document(page_content="\n".join(d["content"] for d in docs))
    documents.append(web_results)
    # Bug fix: return the full document list. The original returned only the
    # single web-results Document ({"documents": web_results}), discarding the
    # previously graded documents it had just appended to.
    return {"documents": documents, "question": question}
def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates documents key with only filtered relevant documents
    """
    logger.debug("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]
    # Score each doc; request a web search as soon as any doc is irrelevant.
    relevant = []
    web_search = "No"
    for candidate in documents:
        verdict = retrieval_grader.invoke({"question": question, "document": candidate.page_content})
        if verdict.binary_score == "yes":
            logger.debug("---GRADE: DOCUMENT RELEVANT---")
            relevant.append(candidate)
        else:
            logger.debug("---GRADE: DOCUMENT NOT RELEVANT---")
            web_search = "Yes"
    return {"documents": relevant, "question": question, "web_search": web_search}
def transform_query(state):
    """
    Transform the query to produce a better question.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Updates question key with a re-phrased question
    """
    logger.debug("---TRANSFORM QUERY---")
    # Re-write the question for better vectorstore retrieval; keep documents.
    better_question = question_rewriter.invoke({"question": state["question"]})
    return {"documents": state["documents"], "question": better_question}
### Edges ###

# For conditional edges
def decide_to_retrieve(state):
    """
    Determines whether to retrieve a context for answering a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call ("retrieve" or "chatbot")
    """
    logger.debug("---ASSESS NEED FOR RETRIEVAL---")
    question = state['messages'][-1].content  # Last Human message
    logger.debug(question)
    response = retriever_grader.invoke({'question': question})
    logger.debug(response)
    logger.debug(response.binary_score)

    if response.binary_score == "yes":
        # Domain-specific question: fetch context from the vectorstore first.
        logger.debug("---DECISION: RETRIEVE DOCUMENTS---")
        return "retrieve"
    # General question: answer directly from the conversation history.
    logger.debug("---DECISION: GENERAL QUESTION, NO RETRIEVAL---")
    return "chatbot"
def decide_to_generate(state):
    """
    Determines whether to generate an answer, or re-generate a question.

    Args:
        state (dict): The current graph state

    Returns:
        str: Binary decision for next node to call
    """
    logger.debug("---ASSESS GRADED DOCUMENTS---")
    # Removed the no-op `state["question"]` / `state["documents"]` expression
    # statements; only the web_search flag drives this decision.
    web_search = state["web_search"]

    if web_search == "Yes":
        # All documents have been filtered by grade_documents;
        # re-generate a better query and fall back to web search.
        logger.debug(
            "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, TRANSFORM QUERY---"
        )
        return "transform_query"
    else:
        # We have relevant documents, so generate the answer.
        logger.debug("---DECISION: GENERATE---")
        return "generate_with_context"
# Prepare and compile the Graph
workflow = StateGraph(GraphState)

# Register the nodes (copy-pasted "# retrieve" labels corrected)
workflow.add_node("chatbot", chatbot)                              # direct answer, no context
workflow.add_node("retrieve", retrieve)                            # vectorstore retrieval
workflow.add_node("grade_documents", grade_documents)              # grade documents
workflow.add_node("generate_with_context", generate_with_context)  # RAG generation
workflow.add_node("transform_query", transform_query)              # rewrite the question
workflow.add_node("web_search_node", web_search)                   # web search fallback

# Wire the edges
workflow.add_conditional_edges(
    START,
    decide_to_retrieve,  # routes to "retrieve" or "chatbot"
)
workflow.add_edge("retrieve", "grade_documents")
workflow.add_conditional_edges(
    "grade_documents",
    decide_to_generate,
    {
        "transform_query": "transform_query",
        "generate_with_context": "generate_with_context",
    },
)
workflow.add_edge("transform_query", "web_search_node")
workflow.add_edge("web_search_node", "generate_with_context")
workflow.add_edge("generate_with_context", "chatbot")
workflow.add_edge("chatbot", END)

# Compile with an in-memory checkpointer so conversations persist per thread_id.
from langgraph.checkpoint.memory import MemorySaver
memory = MemorySaver()
app = workflow.compile(checkpointer=memory, debug=False)
if __name__ == "__main__":
    # Exercise the graph from the command line.
    from pprint import pprint

    # print(retriever.invoke("What is an algorithm?"))

    config = {"configurable": {"thread_id": "abc123"}}

    def query_graph(question: str):
        """Send one question through the graph and return the final state dict."""
        return app.invoke({"messages": [HumanMessage(question)]}, config)

    def print_generation(response):
        """Pretty-print only the generated answer of a graph response."""
        pprint(response['generation'])

    # Example questions (uncomment to try):
    # question = "Cual es el orden de ingreso y egresos de elementos en un Queue?"
    # pprint(query_graph(question))
    # print_generation(query_graph("Hi, my name is George and I would like to learn about algorithms"))
    # print_generation(query_graph("Do you remember my name? What algorithm would you use to reverse the letters in my name?"))
    # print_generation(query_graph("Que es un algoritmo?"))
    # print_generation(query_graph("Que es una heuristica?"))
    # print_generation(query_graph("Que se entiende por orden de crecimiento de un algoritmo?"))
    # print_generation(query_graph("Que es la función tilde?"))
    # print_generation(query_graph("Cuál es la diferencia entre función tilde y orden de crecimiento?"))

    # In stream mode, returns the full 'chatbot' message
    for x in app.stream({"messages": "What is the answer to the question of everything?"}, config):
        print(x)