trident-10 committed on
Commit
d243026
·
verified ·
1 Parent(s): eede7ff

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +484 -0
  2. requirements.txt +26 -0
app.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ###<u> **GENERATIVE AI & LLM PROGRAMMING ASSIGNMENT # 4** </u>
3
+ * **NAME = HASSAN JAVAID**
4
+ * **ROLL NO. = MSCS23001**
5
+ * **TASK = Implementation of Multi-Agentic Retrieval Augmented Generation (RAG) for document and search related queries**
6
+ * **LLM used: CHATGROQ WITH RAG**
7
+
8
+ This file i.e. app.py is shared for deployment on Hugging Face Spaces. This file was submitted as part of course
9
+ CS-500 Generative AI & LLM conducted in ITU, Lahore during Fall-2024.
10
+
11
+ Hugging Face Space Link:
12
+
13
+
14
+ GitHub Repo Link:
15
+
16
+
17
+ This file and the relevant repos are the property of the author and are under the MIT License. Give credit when sharing.
18
+ """
19
+
20
+
21
+ import os
22
+ import asyncio
23
+ import dotenv
24
+ import gradio as gr
25
+ from langchain.schema import HumanMessage
26
+ from langchain_core.prompts import ChatPromptTemplate
27
+ from langchain_core.tools import tool
28
+ from langchain_groq import ChatGroq
29
+ from langchain_community.tools.tavily_search import TavilySearchResults
30
+ from sklearn.metrics.pairwise import cosine_similarity
31
+ from langgraph.graph import MessagesState
32
+ from langchain.vectorstores import Pinecone
33
+ from langchain_huggingface import HuggingFaceEmbeddings
34
+ import pinecone
35
+ from langgraph.graph import StateGraph, START, END
36
+ from langgraph.types import Command
37
+ from typing import Literal
38
+ from typing_extensions import TypedDict
39
+ from langgraph.prebuilt import create_react_agent
40
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
41
+
42
+
# Load environment variables (API keys etc.) from a local .env file, if any.
dotenv.load_dotenv()

# Initialize the Pinecone client with the API key and (optional) environment.
pc = pinecone.Pinecone(
    api_key=os.environ['PINECONE_API_KEY'],
    environment=os.environ.get('PINECONE_ENVIRONMENT')
)

index_name = "gen-ai-hw4"

# Ensure the index exists. The client generation that exposes
# `Pinecone(...)` / `list_indexes().names()` requires a `spec` describing
# where the index lives; omitting it makes `create_index` raise on first run.
if index_name not in pc.list_indexes().names():
    _pinecone_environment = os.environ.get('PINECONE_ENVIRONMENT')
    if _pinecone_environment:
        # A configured environment name identifies a pod-based deployment.
        _index_spec = pinecone.PodSpec(environment=_pinecone_environment)
    else:
        # NOTE(review): serverless defaults are an assumption; override with
        # PINECONE_CLOUD / PINECONE_REGION to match the actual project.
        _index_spec = pinecone.ServerlessSpec(
            cloud=os.environ.get('PINECONE_CLOUD', 'aws'),
            region=os.environ.get('PINECONE_REGION', 'us-east-1'),
        )
    pc.create_index(
        name=index_name,
        dimension=384,  # Dimension of the embedding model
        metric='cosine',
        spec=_index_spec,
    )
index = pc.Index(index_name)

# Sentence-transformer used both for retrieval and for relevance scoring.
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# LangChain vector store over the existing index; chunk text is stored under
# the "text" metadata key.
vector_store = Pinecone.from_existing_index(
    index_name=index_name,
    embedding=embedding_model,
    text_key="text"
)
70
+
71
+
72
+ SYS_PROMPT = """
73
+
74
+ Based on the content of your PDF document, here's a prompt to gather information:
75
+
76
+ "Gather information from the Netsol investor relations report PDF document. Please extract the following data points:
77
+
78
+ 1. Financial Highlights:
79
+ * Revenue figures for the past two years
80
+ * Net income figures for the past two years
81
+ * Gross profit margin percentages for the past two years
82
+ * Total assets and liabilities figures for the past two years
83
+ 2. Board of Directors and Senior Management:
84
+ * Names and positions of the company's board of directors
85
+ * Names and positions of the company's senior management team (including the Chairman, CEO, CFO, etc.)
86
+ 3. Company Profile:
87
+ * Overview of the company's products/services
88
+ * Main business segments
89
+ * Mission and vision statements
90
+ * Brief history of the company
91
+ 4. Visualizations and Graphs:
92
+ * Identify any graphs or charts that show trends in revenue, net income, or other key financial metrics
93
+ * Extract any information from infographics or plots that provide insights into the company's performance or industry trends
94
+ 5. Financial Terms:
95
+ * Define and provide examples of key financial terms used throughout the report (e.g., EBITDA, ROCE, etc.)
96
+ 6. Images and Pictures:
97
+ * Identify the names and roles of the company's board of directors and senior management team mentioned in the report
98
+ * Describe any notable events or milestones mentioned in the report
99
+
100
+ Please organize the extracted information into clear and concise sections, and provide any additional context or clarifications where necessary."
101
+
102
+ """
103
+
104
+ # Agentic Tools Definition
@tool
def multiply(a: int, b: int) -> int:
    """Multiply a and b.

    Args:
        a: first int
        b: second int

    Returns:
        The product of a and b.
    """
    # The docstring above doubles as the tool description shown to the LLM,
    # so it follows the same Args/Returns layout as the other math tools.
    return a * b
114
+
@tool
def add(a: int, b: int) -> int:
    """Add a and b.

    Args:
        a: first int
        b: second int

    Returns:
        The sum of a and b.
    """
    # Docstring is kept verbatim: the @tool decorator uses it as the tool
    # description.
    total = a + b
    return total
127
+
@tool
def subtract(a: int, b: int) -> int:
    """Subtract b from a.

    Args:
        a: first int
        b: second int

    Returns:
        The difference of a and b.
    """
    # Docstring is kept verbatim: the @tool decorator uses it as the tool
    # description.
    difference = a - b
    return difference
140
+
@tool
def divide(a: int, b: int) -> float:
    """Divide a by b.

    Args:
        a: numerator
        b: denominator

    Returns:
        The division of a by b.

    Raises:
        ValueError: If b is zero.
    """
    # Happy path first; a zero denominator falls through to the explicit error.
    if b != 0:
        return a / b
    raise ValueError("Division by zero is not allowed.")
158
+
@tool
def exponentiate(a: int, b: int) -> int:
    """Raise a to the power of b.

    Args:
        a: base
        b: exponent

    Returns:
        a raised to the power of b.
    """
    # Two-argument pow() is the function form of the ** operator.
    return pow(a, b)
171
+
172
+ # Tavily search tool
@tool
def search_tool(query: str, max_results: int = 3) -> str:
    """
    Perform a search query using the Tavily search tool to retrieve information.

    This function utilizes the Tavily search tool to perform a web search
    for the given query and returns the results. It is useful for answering
    questions or retrieving information from the web.

    Args:
        query: The search query string to be executed.
        max_results: The maximum number of search results
            to retrieve. Defaults to 3.

    Returns:
        str: A string containing the search results. If an error occurs during
            the search, an error message is returned instead.

    Raises:
        Exception: If there is an issue with the Tavily search tool invocation.

    Example:
        >>> search_tool("Who won the last match between Pakistan and Zimbabwe?")
        'Pakistan won the last match by 5 wickets.'
    """
    # Trace line kept so tool routing stays visible in the Space logs.
    print("In search")
    searcher = TavilySearchResults(max_results=max_results)
    try:
        results = searcher.invoke(query)
    except Exception as e:
        # Report the failure to the agent as text instead of aborting the run.
        return f"Error performing search: {e}"
    else:
        return results
204
+
205
+
206
+
207
+ # Fetches document score
def scoreDocuments(docs, query, embedding_model, threshold=0.7):
    """
    Scores the relevance of documents to the query using cosine similarity.

    Args:
        docs: List of retrieved documents (each exposing ``page_content``).
        query: The user query.
        embedding_model: Embeddings instance providing ``embed_query`` and
            ``embed_documents`` (e.g. HuggingFaceEmbeddings).
        threshold: Minimum relevance score to consider documents relevant.

    Returns:
        bool: Whether every document meets the threshold. ``False`` when no
            documents were retrieved.
        list: List of relevance scores as floats, one per document.
    """
    # all() over an empty iterable is True, which would report "relevant"
    # with nothing retrieved; an empty result set is never relevant.
    if not docs:
        return False, []

    # Generate embedding for the query
    query_embedding = embedding_model.embed_query(query)

    # Embed all documents in one batched call rather than one call per document
    doc_embeddings = embedding_model.embed_documents([doc.page_content for doc in docs])

    # One similarity computation: row 0 holds query-vs-each-document scores
    scores = cosine_similarity([query_embedding], doc_embeddings)[0].tolist()

    # Check if all scores meet the relevance threshold
    is_relevant = all(score >= threshold for score in scores)
    return is_relevant, scores
234
+
235
+
236
+ # Augments the prompt
def augmentPrompt(context: str, query: str) -> str:
    """
    Build the full LLM prompt from the system prompt, the question and the context.

    Args:
        context: The retrieved document context for the query.
        query: The user's original query.

    Returns:
        str: SYS_PROMPT followed by the user's question and the retrieved context.
    """
    # Returned directly; layout is system prompt, then question, then context.
    return f"""
    {SYS_PROMPT}

    The user asked: {query}

    The relevant context is:

    {context}
    """
259
+
260
+
261
+ # Tool Definition
def _format_chunks(docs) -> str:
    """Render retrieved documents as numbered markdown chunks for the prompt."""
    return ''.join(f'## Chunk {i}:\n\n{doc.page_content}\n\n' for i, doc in enumerate(docs))


@tool
def doc_query_tool(query: str) -> str:
    """
    Fetches relevant context from Pinecone, scores relevance, and handles query refinement if needed.
    Invokes the Groq LLM for generating responses.

    Args:
        query: The user's query.

    Returns:
        str: The response generated by the LLM based on the provided or refined query.
    """
    print("In doc_query")
    # Retrieve relevant documents using LangChain's Pinecone integration.
    # `invoke` is the Runnable entry point that replaces the deprecated
    # `get_relevant_documents`.
    retriever = vector_store.as_retriever(search_kwargs={"k": 3})
    retrieved_docs = retriever.invoke(query)

    # Score documents for relevance
    is_relevant, _scores = scoreDocuments(retrieved_docs, query, embedding_model, threshold=0.5)
    if is_relevant:
        print("In is_relevant")
        context_label = "context"
    else:
        # Rewrite the query using the LLM, then retrieve again with it
        print("In query rewrite")
        chat_model = ChatGroq(model="llama3-8b-8192", api_key=os.environ["GROQ_API_KEY"])
        rewrite_msg = [
            HumanMessage(
                content=f""" \n
    Look at the input and try to reason about the underlying semantic intent/meaning. \n
    Here is the initial question:
    \n ------- \n
    {query}
    \n ------- \n
    Formulate an improved question: """,
            )
        ]
        query = chat_model.invoke(rewrite_msg).content
        retrieved_docs = retriever.invoke(query)
        context_label = "new_context"

    # Generate the prompt from whichever document set was finally selected
    context = _format_chunks(retrieved_docs)
    response = llm.invoke([HumanMessage(content=augmentPrompt(context, query))])
    if context:
        print(f"{context_label} = {context}")

    # Return only the text: handing the message object back would make the
    # agent see its stringified repr (metadata included) as the tool result.
    return response.content
319
+
@tool
def general_answer_tool(query: str) -> str:
    """Tool for handling non-specific queries (e.g., facts or definitions)."""
    print("In general")
    response = llm.invoke([HumanMessage(content=f"Answer the following general query: {query}")])
    # Return the answer text rather than the message object, so the agent
    # receives a clean string as the tool result.
    return response.content
327
+
328
+
329
+ # LangGraph and Nodes/Agents Setup
# Worker nodes the supervisor may route to; "FINISH" is the extra routing
# option that ends the run.
members = ['doc_query', 'tavilysearch', 'general']
options = [*members, "FINISH"]

# Supervisor instructions; `{members}` is a template placeholder.
system_prompt = """
You are a supervisor tasked with managing a conversation between the following workers: {members}.
Given the following user request, respond with the worker to act next.
Each worker will perform a task and respond with their results and status.
When finished, respond with FINISH.
"""


# Structured-output schema for the supervisor's routing decision.
# (Documented with comments only, leaving the TypedDict itself unchanged.)
class Router(TypedDict):
    next: Literal['doc_query', 'tavilysearch', 'general', "FINISH"]


# Shared chat model used by the supervisor, the agents and the tools.
llm = ChatGroq(temperature=0, model_name="mixtral-8x7b-32768", api_key=os.environ['GROQ_API_KEY'])

# Routing prompt: instructions, the conversation so far, then the question of
# who acts next; `options` and `members` are pre-filled via partial().
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Given the conversation above, who should act next?"
            " Or should we FINISH? Select one of: {options}",
        ),
    ]
)
prompt = prompt.partial(options=options, members=", ".join(members))
357
+
358
+
359
+ # Supervisor Node Setup
def supervisor_node(state: MessagesState) -> Command[Literal['doc_query', 'tavilysearch', 'general', "__end__"]]:
    """Ask the LLM which worker should act next and route the graph there.

    Args:
        state: Current graph state; its "messages" list is the conversation so far.

    Returns:
        Command: Jump to the chosen worker node, or to END when the model
            answers "FINISH".
    """
    # Fill in the `{members}` placeholder; sending the raw template would show
    # the model a literal "{members}" instead of the worker names.
    instructions = system_prompt.format(members=", ".join(members))
    messages = [{"role": "system", "content": instructions}] + state["messages"]
    response = llm.with_structured_output(Router).invoke(messages)
    goto = response["next"]
    if goto == "FINISH":
        goto = END
    return Command(goto=goto)
369
+
370
+ # Agents Setup
371
+ # Math Agent
372
+ # math_prompt = "Peform arithmetic operations using your given tools"
math_agent = create_react_agent(
    llm,
    tools=[multiply, add, subtract, divide, exponentiate],
    state_modifier="You will ONLY DO math.",
)


def math_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    """Run the math agent on the current state and hand control back to the supervisor."""
    agent_state = math_agent.invoke(state)
    # Only the agent's final message is forwarded, tagged with the worker name.
    final_text = agent_state["messages"][-1].content
    reply = HumanMessage(content=final_text, name="math")
    return Command(goto="supervisor", update={"messages": [reply]})
383
+
384
+ # Search Agent
search_agent = create_react_agent(
    llm,
    tools=[search_tool],
    state_modifier="You are a researcher. DO NOT do any math.",
)


def search_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    """Run the web-search agent on the current state and return to the supervisor."""
    agent_state = search_agent.invoke(state)
    # Only the agent's final message is forwarded, tagged with the worker name.
    final_text = agent_state["messages"][-1].content
    reply = HumanMessage(content=final_text, name="tavilysearch")
    return Command(goto="supervisor", update={"messages": [reply]})
395
+
396
+ # Document Query Agent
doc_query_agent = create_react_agent(
    llm,
    tools=[doc_query_tool],
    state_modifier="You will only look into retreived documents for answer. DO NOT search on internet.",
)


def doc_query_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    """Run the document-query agent on the current state and return to the supervisor."""
    agent_state = doc_query_agent.invoke(state)
    # Only the agent's final message is forwarded, tagged with the worker name.
    final_text = agent_state["messages"][-1].content
    reply = HumanMessage(content=final_text, name="doc_query")
    return Command(goto="supervisor", update={"messages": [reply]})
407
+
408
+ # # General Answer Agent
general_agent = create_react_agent(
    llm,
    tools=[general_answer_tool],
    state_modifier="You will ONLY GIVE answer to the query if no else tool can give an answer. DO NOT do math.",
)


def general_node(state: MessagesState) -> Command[Literal["supervisor"]]:
    """Run the general-answer agent on the current state and return to the supervisor."""
    print("In general_node")
    agent_state = general_agent.invoke(state)
    # Only the agent's final message is forwarded, tagged with the worker name.
    final_text = agent_state["messages"][-1].content
    reply = HumanMessage(content=final_text, name="general")
    return Command(goto="supervisor", update={"messages": [reply]})
421
+
422
+ # Build the StateGraph
builder = StateGraph(MessagesState)

# Register the supervisor and the three active workers. The math node is
# deliberately left out, matching the commented-out registration it replaces.
for _node_name, _node_fn in (
    ("supervisor", supervisor_node),
    ("tavilysearch", search_node),
    ("doc_query", doc_query_node),
    ("general", general_node),
):
    builder.add_node(_node_name, _node_fn)

# Every run enters at the supervisor; all further routing happens through the
# Command objects returned by the nodes.
builder.add_edge(START, "supervisor")
graph = builder.compile()
432
+
433
+
434
+ # Gradio App Creation
def convertQueryToInputsFormat(query):
    """Wrap a raw user query in the ``{"messages": [...]}`` input shape the graph expects."""
    human_turn = ('human', query)
    return dict(messages=[human_turn])
437
+
438
+
async def getFinalGraphResponse(graph, inputs, stream_mode="values"):
    """Stream the graph to completion and return the last chunk it emitted.

    Returns None when the stream yields nothing.
    """
    last_seen = None
    stream = graph.astream(inputs, stream_mode=stream_mode)
    async for emitted in stream:
        # Keep overwriting: only the final emission matters to the caller.
        last_seen = emitted
    return last_seen
444
+
def getResponse(input_text):
    """Answer a user question by running the agent graph to completion.

    Args:
        input_text: The question typed into the Gradio textbox.

    Returns:
        str: Content of the last message in the final graph state, or
            "No response received." when the graph produced no messages.
    """
    inputs = convertQueryToInputsFormat(input_text)

    # asyncio.run creates a fresh event loop and closes it afterwards. The
    # previous get_event_loop()/new_event_loop() fallback left every loop it
    # created open for the lifetime of the worker thread.
    final_output = asyncio.run(getFinalGraphResponse(graph, inputs))

    if not final_output or "messages" not in final_output:
        return "No response received."

    messages = final_output["messages"]
    # An empty message list would otherwise raise IndexError on [-1].
    if not messages:
        return "No response received."
    return messages[-1].content
461
+
462
+ # Create the Gradio Interface
# Single-textbox Gradio UI around getResponse.
iface = gr.Interface(
    fn=getResponse,
    inputs=gr.Textbox(
        label="Enter your question",
        placeholder="Type your question here..."
    ),
    outputs="textbox",
    title="Researcher and Doc-Query Handler",
    description=(
        # The two adjacent literals are concatenated; the period keeps the
        # sentences from running together in the rendered description.
        "Ask a question about NetSol Financial Report or internet related query. "
        "This assistant looks up relevant documents if needed and then answers your question."
    ),
    examples=[
        ["What are the main objectives outlined in NETSOL's mission statement?"],
        ["Who won first t20 match between Pakistan and Zimbabwe?"],
        ["Who is the CEO of Huawei?"]
    ],
    theme=gr.themes.Soft(),
    allow_flagging="never"
)

# Start the app; share=False means no public share link is requested.
iface.launch(share=False)
requirements.txt ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langgraph
2
+ langgraph-sdk
3
+ langgraph-checkpoint-sqlite
4
+ langsmith
5
+ langchain-community
6
+ langchain-core
7
+ langchain-openai
8
+ langchain-huggingface
9
+ langchain-pinecone
10
+ notebook
11
+ tavily-python
12
+ wikipedia
13
+ trustcall
14
+ langgraph-cli
15
+ langchain-groq
16
+ langchain-anthropic
17
+ python-dotenv
18
+ pydantic
19
+ unstructured[all-docs]
20
+ pinecone[grpc]
21
+ pymupdf
22
+ ragas
23
+ datasets
24
+ gradio
25
+ langserve
26
+ pymupdf