AseemD commited on
Commit
9dcd5de
·
verified ·
1 Parent(s): 6be7795

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +175 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import os
3
+ import wikipediaapi
4
+ import gradio as gr
5
+ from groq import Groq
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_openai import OpenAIEmbeddings
8
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
9
+
10
+ from utils.context import system_prompt
11
+ os.environ["GROQ_API_KEY"] = "gsk..."
12
+
13
+ # Agent Class
14
+ class Agent:
15
+ def __init__(self, client, system):
16
+ self.client = client
17
+ self.system = system
18
+ self.memory = []
19
+ # If there is no memory, initialize it with the system message
20
+ if self.memory is not None:
21
+ self.memory = [{"role": "system", "content": self.system}]
22
+
23
+ def __call__(self, message=""):
24
+ if message:
25
+ self.memory.append({"role": "user", "content": message})
26
+ result = self.execute()
27
+ self.memory.append({"role": "assistant", "content": result})
28
+ return result
29
+
30
+ def execute(self):
31
+ completion = client.chat.completions.create(
32
+ messages = self.memory,
33
+ model="llama3-70b-8192",
34
+ )
35
+ return completion.choices[0].message.content
36
+
37
+ # Gloabal variables
38
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
39
+ wiki = wikipediaapi.Wikipedia(language='en', user_agent="aseem" )
40
+ embeddings = OpenAIEmbeddings()
41
+ faiss_store = None
42
+
43
+ # Utils/Tools for the agent
44
+ def calculate(operation):
45
+ return eval(operation)
46
+
47
+ def wikipedia_search(query, advanced_query, advanced_search=False, top_k=5):
48
+ global faiss_store
49
+ page = wiki.page(query)
50
+
51
+ # Check if the page exists
52
+ if page.exists():
53
+ if advanced_search:
54
+ # Get the full content of the Wikipedia page
55
+ content = page.text
56
+ # Split the content into chunks
57
+ chunks = chunk_text(content)
58
+ # Store the chunks in FAISS
59
+ faiss_store = store_in_faiss(chunks)
60
+ # Retrieve the top-k relevant chunks
61
+ top_k_documents = retrieve_top_k(advanced_query, top_k)
62
+ # Return the retrieved documents
63
+ return f"Context: {" ".join(top_k_documents)}\n"
64
+ else:
65
+ return f"Summary: {page.summary}\n"
66
+ else:
67
+ return f"The page '{query}' does not exist on Wikipedia."
68
+
69
+
70
+ def chunk_text(text, chunk_size=512, chunk_overlap=50):
71
+ """
72
+ Uses LangChain's RecursiveCharacterTextSplitter to chunk the text.
73
+ """
74
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
75
+ chunks = splitter.split_text(text)
76
+ return chunks
77
+
78
+ def store_in_faiss(chunks):
79
+ """
80
+ Stores the chunks in a FAISS vector store.
81
+ """
82
+ vector_store = FAISS.from_texts(chunks, embeddings)
83
+ return vector_store
84
+
85
+ def retrieve_top_k(query, top_k=5):
86
+ """
87
+ Retrieves the top-k most relevant chunks from FAISS.
88
+ """
89
+ if faiss_store is None:
90
+ return "No vector data available. Perform advanced search first."
91
+
92
+ # Retrieve top-k documents
93
+ docs_and_scores = faiss_store.similarity_search_with_score(query, top_k)
94
+ top_k_chunks = [doc.page_content for doc, score in docs_and_scores]
95
+ return top_k_chunks
96
+
97
+ # Automatic execution of the agent
98
+ def run_agent(max_iterations=10, query: str = "", display_reasoning=True):
99
+ agent = Agent(client=client, system=system_prompt)
100
+ tools = ["calculate", "wikipedia_search"]
101
+ next_prompt = query
102
+ iteration = 0
103
+ steps = 1
104
+ partial_results = ""
105
+
106
+ while iteration < max_iterations:
107
+ iteration += 1
108
+ result = agent(next_prompt)
109
+
110
+ if display_reasoning:
111
+ partial_results += f" -------- (Step {steps}) -------- \n"
112
+ steps += 1
113
+ partial_results += result + "\n\n"
114
+ yield partial_results
115
+
116
+ if "Thought" in result and "Action" in result:
117
+ action = re.findall(r"Action: ([a-z_]+): (.+)", result, re.IGNORECASE)
118
+ chosen_tool = action[0][0]
119
+ args = action[0][1]
120
+ if chosen_tool in tools:
121
+ if chosen_tool == "calculate":
122
+ tool_result = eval(f"{chosen_tool}({'args'})")
123
+ next_prompt = f"Observation: {tool_result}"
124
+ else:
125
+ tool_result = eval(f"{chosen_tool}({args})")
126
+ next_prompt = f"Observation: {tool_result}"
127
+ else:
128
+ next_prompt = "Observation: Tool not found"
129
+
130
+ if display_reasoning:
131
+ partial_results += f" -------- (Step {steps}) -------- \n"
132
+ steps += 1
133
+ partial_results += next_prompt[:100] + " ..." + "\n\n"
134
+ yield partial_results
135
+ continue
136
+
137
+ if "Answer" in result:
138
+ if display_reasoning:
139
+ yield partial_results
140
+ else:
141
+ partial_results += result.split("Answer:")[-1].strip()
142
+ yield partial_results
143
+ break
144
+
145
+ if iteration >= max_iterations:
146
+ partial_text += "\nThe Wikipedia AI Agent is likely hallucinating. Please try again :("
147
+ yield partial_text
148
+
149
+ def generate_response_stream(message, show_reasoning):
150
+ # If show_reasoning = True, we'll show all the partial steps
151
+ # If show_reasoning = False, we only yield the final answer
152
+ yield from run_agent(query=message, display_reasoning=show_reasoning)
153
+
154
+ def main():
155
+ interface = gr.Interface(
156
+ fn=generate_response_stream,
157
+ inputs=[
158
+ gr.Textbox(label="Ask your question here:"),
159
+ gr.Checkbox(label="Show reasoning")
160
+ ],
161
+ outputs=gr.Textbox(label="Agent Output"),
162
+ title="Wikipedia AI Agent",
163
+ description= (
164
+ "Ask a question to the Wikipedia AI Agent."
165
+ "For eg: \n"
166
+ "- \"What is the weight of a tiger?\" \n"
167
+ "- \"Why are fiber optic cables so fragile?\" \n"
168
+ "- \"How does an internal combustion engine work?\" \n"
169
+ )
170
+ )
171
+ interface.launch()
172
+
173
+
174
+ if __name__ == "__main__":
175
+ main()