swayam-the-coder committed on
Commit
c241f79
·
verified ·
1 Parent(s): 5e4ac43

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +534 -0
app.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from pathlib import Path
4
+ from tempfile import TemporaryDirectory
5
+ from langchain_core.messages import BaseMessage, HumanMessage
6
+ from typing import Annotated, List, Optional, Dict
7
+ from typing_extensions import TypedDict
8
+ from langchain_community.document_loaders import WebBaseLoader
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_core.tools import tool
11
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
12
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
13
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
+ from langchain_openai import ChatOpenAI
15
+ from langgraph.graph import END, StateGraph, START
16
+ import functools
17
+ import operator
18
+ import logging
19
+ import time
20
+ from tenacity import retry, stop_after_attempt, wait_exponential
21
+
22
+ # Set up logging
23
+ logging.basicConfig(level=logging.INFO)
24
+ logger = logging.getLogger(__name__)
25
+
26
# Initialize temporary directory (once per Streamlit session).
if "working_directory" not in st.session_state:
    _TEMP_DIRECTORY = TemporaryDirectory()
    # Keep the TemporaryDirectory object itself in session state. If only the
    # path were stored, the object would be bound solely to this script-level
    # name and its cleanup could delete the folder once it is collected
    # after the current script run.
    st.session_state["temp_directory_handle"] = _TEMP_DIRECTORY
    st.session_state.working_directory = Path(_TEMP_DIRECTORY.name)

WORKING_DIRECTORY = st.session_state.working_directory
# Later code lists and writes into this folder unconditionally, so make sure it
# exists even if it was removed between reruns.
WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
32
+
33
+ # Streamlit UI
34
+ st.set_page_config(page_title="MARS: Multi-Agent Report Synthesizer", layout="wide")
35
+
36
+ # Custom CSS for styling
37
+ st.markdown("""
38
+ <style>
39
+ body {
40
+ background-color: #f5f5f5;
41
+ color: #333333;
42
+ font-family: 'Comic Sans MS', 'Comic Sans', cursive;
43
+ }
44
+ .report-container {
45
+ border-radius: 10px;
46
+ background-color: #ffcccb;
47
+ padding: 20px;
48
+ }
49
+ .sidebar .sidebar-content {
50
+ background-color: #333333;
51
+ color: #ffffff;
52
+ }
53
+ .stButton button {
54
+ background-color: #ff6347;
55
+ color: #ffffff;
56
+ border-radius: 5px;
57
+ font-size: 18px;
58
+ padding: 10px 20px;
59
+ font-weight: bold;
60
+ }
61
+ .stTextInput input {
62
+ border-radius: 5px;
63
+ border: 2px solid #ff6347;
64
+ font-size: 16px;
65
+ padding: 10px;
66
+ width: 100%;
67
+ }
68
+ .stTextInput label {
69
+ font-size: 18px;
70
+ font-weight: bold;
71
+ color: #333333;
72
+ }
73
+ .stSelectbox label, .stDownloadButton label {
74
+ font-size: 18px;
75
+ font-weight: bold;
76
+ color: #333333;
77
+ }
78
+ .stSelectbox div, .stDownloadButton div {
79
+ background-color: #ffcccb;
80
+ color: #333333;
81
+ border-radius: 5px;
82
+ padding: 10px;
83
+ font-size: 16px;
84
+ }
85
+ </style>
86
+ """, unsafe_allow_html=True)
87
+
88
+ st.title("🚀 MARS: Multi-agent Report Synthesizer 🤖")
89
+ st.sidebar.title("📋 Instructions")
90
+ st.sidebar.write("""
91
+ 1. Enter your query in the input box.
92
+ 2. Marvin AI will assign tasks to different teams.
93
+ 3. You can see the progress and download the final report.
94
+ 4. Use the buttons to list and download output files.
95
+ """)
96
+
97
# Input fields for API keys
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
tavily_api_key = st.sidebar.text_input("Tavily API Key", type="password")

# Export the keys through environment variables (not session state).
# NOTE(review): os.environ is process-wide, so on a shared deployment the most
# recently entered key is visible to every session -- confirm this is acceptable.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
if tavily_api_key:
    os.environ["TAVILY_API_KEY"] = tavily_api_key

# Check if the API keys are set; report every missing key before halting.
_missing_api_key = False
if not os.getenv("OPENAI_API_KEY"):
    st.error("OpenAI API Key is required.")
    _missing_api_key = True
if not os.getenv("TAVILY_API_KEY"):
    st.error("Tavily API Key is required.")
    _missing_api_key = True
if _missing_api_key:
    # The code below builds the chat model and the Tavily search tool, which
    # presumably need these keys at construction time. Halt this script run so
    # the user sees the messages above instead of a traceback.
    st.stop()
112
+
113
+ # Define tools
114
+ @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
115
+ def tavily_search_with_retry(*args, **kwargs):
116
+ return TavilySearchResults(*args, **kwargs)
117
+
118
+ tavily_tool = tavily_search_with_retry(max_results=5)
119
+
120
@tool
def scrape_webpages(urls: List[str]) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    # Any failure is logged and reported back as text so the calling agent
    # receives a message instead of an exception.
    try:
        documents = WebBaseLoader(urls).load()
        sections = [f'\n{document.page_content}\n' for document in documents]
        return "\n\n".join(sections)
    except Exception as e:
        logger.error(f"Error in scrape_webpages: {str(e)}")
        return f"Error occurred while scraping webpages: {str(e)}"
135
+
136
@tool
def create_outline(
    points: Annotated[List[str], "List of main points or sections."],
    file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
    """Create and save an outline."""
    # Writes one numbered line ("1. ...", "2. ...") per point into the working
    # directory; errors are logged and returned as text.
    try:
        outline_path = WORKING_DIRECTORY / file_name
        with outline_path.open("w") as handle:
            handle.writelines(
                f"{number}. {point}\n"
                for number, point in enumerate(points, start=1)
            )
        return f"Outline saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in create_outline: {str(e)}")
        return f"Error occurred while creating outline: {str(e)}"
150
+
151
@tool
def read_document(
    file_name: Annotated[str, "File path to save the document."],
    start: Annotated[Optional[int], "The start line. Default is 0"] = None,
    end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
    """Read the specified document."""
    # Returns the lines [start:end] of the file inside the working directory.
    # On any failure the error is logged and returned as text.
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Fall back to the first line only when no start was given. The
        # previous check was inverted and reset every explicit start to 0.
        if start is None:
            start = 0
        # readlines() keeps each line's trailing newline, so join with an
        # empty separator to avoid doubling every line break.
        return "".join(lines[start:end])
    except Exception as e:
        logger.error(f"Error in read_document: {str(e)}")
        return f"Error occurred while reading document: {str(e)}"
167
+
168
@tool
def write_document(
    content: Annotated[str, "Text content to be written into the document."],
    file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
    """Create and save a text document."""
    # Overwrites (or creates) the file inside the working directory; errors
    # are logged and returned as text.
    try:
        destination = WORKING_DIRECTORY / file_name
        with destination.open("w") as handle:
            handle.write(content)
    except Exception as e:
        logger.error(f"Error in write_document: {str(e)}")
        return f"Error occurred while writing document: {str(e)}"
    return f"Document saved to {file_name}"
181
+
182
@tool
def edit_document(
    file_name: Annotated[str, "Path of the document to be edited."],
    inserts: Annotated[
        Dict[int, str],
        "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
    ],
) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers."""
    # Insertions are applied in ascending line-number order against the list
    # as it grows. An out-of-range line number aborts before anything is
    # written back to disk.
    try:
        target = WORKING_DIRECTORY / file_name
        with target.open("r") as reader:
            lines = reader.readlines()

        for line_number, text in sorted(inserts.items()):
            if not 1 <= line_number <= len(lines) + 1:
                return f"Error: Line number {line_number} is out of range."
            lines.insert(line_number - 1, text + "\n")

        with target.open("w") as writer:
            writer.writelines(lines)
        return f"Document edited and saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in edit_document: {str(e)}")
        return f"Error occurred while editing document: {str(e)}"
206
+
207
+ # Define the agents and their tools
208
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
209
+
210
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> AgentExecutor:
    """Build a function-calling agent and wrap it in an executor.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call.
        system_prompt: Role description; a shared collaboration suffix is
            appended to it.

    Returns:
        An ``AgentExecutor`` for the agent (the previous ``-> str``
        annotation was wrong).

    Note:
        The resulting prompt contains the ``{team_members}`` placeholder as
        well as any placeholder present in ``system_prompt``, so the state the
        executor is invoked with must provide those keys.
    """
    # NOTE(review): the continuation lines of this literal are part of the
    # prompt text, including their leading whitespace.
    system_prompt += """\nWork autonomously according to your specialty, using the tools available to you.
    Do not ask for clarification.
    Your other team members (and other teams) will collaborate with you with their own specialties.
    You are chosen for a reason! You are one of the following team members: {team_members}."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor
226
+
227
def agent_node(state, agent, name):
    """Run ``agent`` on ``state`` and wrap its output as a message named ``name``.

    Whether the agent succeeds or raises, the return value is a dict with a
    single-element ``messages`` list holding a ``HumanMessage``; on failure
    the message carries the error text instead of the agent output.
    """
    try:
        logger.info(f"Starting {name} agent")
        outcome = agent.invoke(state)
        logger.info(f"{name} agent completed with result: {outcome}")
        reply = HumanMessage(content=outcome["output"], name=name)
        return {"messages": [reply]}
    except Exception as e:
        logger.error(f"Error in {name} agent: {str(e)}")
        failure = HumanMessage(
            content=f"Error occurred in {name} agent: {str(e)}", name=name
        )
        return {"messages": [failure]}
236
+
237
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members):
    """An LLM-based router.

    Args:
        llm: Chat model used to pick the next actor.
        system_prompt: Supervisor instructions; may reference
            ``{team_members}``, which is pre-filled from ``members``.
        members: Names of the workers/teams the supervisor can route to.

    Returns:
        A pipeline (prompt | function-bound model | JSON function-call parser),
        not a ``str`` as the previous annotation claimed. The bound ``route``
        function requires a ``next`` field restricted to ``"FINISH"`` or one
        of ``members``.
    """
    options = ["FINISH"] + members
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    system_prompt += "\nEnsure that you direct the workflow to completion. If no progress is being made, or if the task seems complete, choose FINISH."
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    return (
        prompt
        # Force the model to answer through the "route" function.
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
270
+
271
+ # ResearchTeam graph state
272
+ class ResearchTeamState(TypedDict):
273
+ messages: Annotated[List[BaseMessage], operator.add]
274
+ team_members: List[str]
275
+ next: str
276
+
277
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
278
+
279
+ search_agent = create_agent(
280
+ llm,
281
+ [tavily_tool],
282
+ "You are a research assistant who can search for up-to-date info using the tavily search engine.",
283
+ )
284
+ search_node = functools.partial(agent_node, agent=search_agent, name="Search")
285
+
286
+ research_agent = create_agent(
287
+ llm,
288
+ [scrape_webpages],
289
+ "You are a research assistant who can scrape specified urls for more detailed information using the scrape_webpages function.",
290
+ )
291
+ research_node = functools.partial(agent_node, agent=research_agent, name="WebScraper")
292
+
293
+ supervisor_agent = create_team_supervisor(
294
+ llm,
295
+ "You are a supervisor tasked with managing a conversation between the"
296
+ " following workers: Search, WebScraper. Given the following user request,"
297
+ " respond with the worker to act next. Each worker will perform a"
298
+ " task and respond with their results and status. When finished,"
299
+ " respond with FINISH.",
300
+ ["Search", "WebScraper"],
301
+ )
302
+
303
+ research_graph = StateGraph(ResearchTeamState)
304
+ research_graph.add_node("Search", search_node)
305
+ research_graph.add_node("WebScraper", research_node)
306
+ research_graph.add_node("supervisor", supervisor_agent)
307
+
308
+ # Define the control flow
309
+ research_graph.add_edge("Search", "supervisor")
310
+ research_graph.add_edge("WebScraper", "supervisor")
311
+ research_graph.add_conditional_edges(
312
+ "supervisor",
313
+ lambda x: x["next"],
314
+ {"Search": "Search", "WebScraper": "WebScraper", "FINISH": END},
315
+ )
316
+
317
+ research_graph.add_edge(START, "supervisor")
318
+ chain = research_graph.compile()
319
+
320
def enter_chain(message: str):
    """Build the initial research-team state from a plain text request.

    Besides the request message, the state carries ``team_members``: every
    prompt produced by ``create_agent`` contains a ``{team_members}``
    placeholder, so omitting the key leaves that placeholder unfilled when the
    Search / WebScraper agents are invoked.
    """
    results = {
        "messages": [HumanMessage(content=message)],
        # A list, matching ResearchTeamState.team_members (List[str]).
        "team_members": ["Search", "WebScraper"],
    }
    return results
325
+
326
+ research_chain = enter_chain | chain
327
+
328
+ # Document writing team graph state
329
+ class DocWritingState(TypedDict):
330
+ messages: Annotated[List[BaseMessage], operator.add]
331
+ team_members: str
332
+ next: str
333
+ current_files: str
334
+
335
def prelude(state):
    """Add a ``current_files`` summary of the working directory to ``state``.

    Returns a copy of ``state`` with ``current_files`` set either to
    "No files written." or to a bulleted listing of every path found under
    the working directory (relative to it).
    """
    written_files = []
    try:
        # Create the folder (and any missing parents) if it is absent.
        WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
        written_files = [
            f.relative_to(WORKING_DIRECTORY) for f in WORKING_DIRECTORY.rglob("*")
        ]
    except Exception as e:
        # Stay best-effort, but leave a trace instead of swallowing silently.
        logger.warning(f"Could not list working directory files: {str(e)}")
    if not written_files:
        return {**state, "current_files": "No files written."}
    return {
        **state,
        "current_files": "\nBelow are files your team has written to the directory:\n"
        + "\n".join([f" - {f}" for f in written_files]),
    }
352
+
353
+ doc_writer_agent = create_agent(
354
+ llm,
355
+ [write_document, edit_document, read_document],
356
+ "You are an expert writing a research document.\n"
357
+ "Below are files currently in your directory:\n{current_files}",
358
+ )
359
+ context_aware_doc_writer_agent = prelude | doc_writer_agent
360
+ doc_writing_node = functools.partial(
361
+ agent_node, agent=context_aware_doc_writer_agent, name="DocWriter"
362
+ )
363
+
364
+ note_taking_agent = create_agent(
365
+ llm,
366
+ [create_outline, read_document],
367
+ "You are an expert senior researcher tasked with writing a paper outline and"
368
+ " taking notes to craft a perfect paper.{current_files}",
369
+ )
370
+ context_aware_note_taking_agent = prelude | note_taking_agent
371
+ note_taking_node = functools.partial(
372
+ agent_node, agent=context_aware_note_taking_agent, name="NoteTaker"
373
+ )
374
+
375
# Chart generator: may only read documents; sees the current file listing
# through the {current_files} placeholder filled in by prelude.
chart_generating_agent = create_agent(
    llm,
    [read_document],
    "You are a data viz expert tasked with generating charts for a research project."
    "{current_files}",
)
context_aware_chart_generating_agent = prelude | chart_generating_agent
# Bind the node to the chart agent. It was previously wired to
# context_aware_note_taking_agent, so the "ChartGenerator" node ran the note
# taker instead.
chart_generating_node = functools.partial(
    agent_node, agent=context_aware_chart_generating_agent, name="ChartGenerator"
)
385
+
386
+ doc_writing_supervisor = create_team_supervisor(
387
+ llm,
388
+ "You are a supervisor tasked with managing a conversation between the"
389
+ " following workers: {team_members}. Given the following user request,"
390
+ " respond with the worker to act next. Each worker will perform a"
391
+ " task and respond with their results and status. When finished,"
392
+ " respond with FINISH.",
393
+ ["DocWriter", "NoteTaker", "ChartGenerator"],
394
+ )
395
+
396
+ authoring_graph = StateGraph(DocWritingState)
397
+ authoring_graph.add_node("DocWriter", doc_writing_node)
398
+ authoring_graph.add_node("NoteTaker", note_taking_node)
399
+ authoring_graph.add_node("ChartGenerator", chart_generating_node)
400
+ authoring_graph.add_node("supervisor", doc_writing_supervisor)
401
+
402
+ authoring_graph.add_edge("DocWriter", "supervisor")
403
+ authoring_graph.add_edge("NoteTaker", "supervisor")
404
+ authoring_graph.add_edge("ChartGenerator", "supervisor")
405
+ authoring_graph.add_conditional_edges(
406
+ "supervisor",
407
+ lambda x: x["next"],
408
+ {
409
+ "DocWriter": "DocWriter",
410
+ "NoteTaker": "NoteTaker",
411
+ "ChartGenerator": "ChartGenerator",
412
+ "FINISH": END,
413
+ },
414
+ )
415
+
416
+ authoring_graph.add_edge(START, "supervisor")
417
+ chain = authoring_graph.compile()
418
+
419
+ def enter_chain(message: str, members: List[str]):
420
+ results = {
421
+ "messages": [HumanMessage(content=message)],
422
+ "team_members": ", ".join(members),
423
+ }
424
+ return results
425
+
426
+ authoring_chain = (
427
+ functools.partial(enter_chain, members=authoring_graph.nodes)
428
+ | authoring_graph.compile()
429
+ )
430
+
431
+ llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
432
+
433
+ supervisor_node = create_team_supervisor(
434
+ llm,
435
+ "You are a supervisor tasked with managing a conversation between the"
436
+ " following teams: {team_members}. Given the following user request,"
437
+ " respond with the worker to act next. Each worker will perform a"
438
+ " task and respond with their results and status. Make sure each team is used atleast once. When finished,"
439
+ " respond with FINISH.",
440
+ ["ResearchTeam", "PaperWritingTeam"],
441
+ )
442
+
443
+ class State(TypedDict):
444
+ messages: Annotated[List[BaseMessage], operator.add]
445
+ next: str
446
+
447
+ def get_last_message(state: State) -> str:
448
+ return state["messages"][-1].content
449
+
450
+ def join_graph(response: dict):
451
+ return {"messages": [response["messages"][-1]]}
452
+
453
+ super_graph = StateGraph(State)
454
+ super_graph.add_node("ResearchTeam", get_last_message | research_chain | join_graph)
455
+ super_graph.add_node("PaperWritingTeam", get_last_message | authoring_chain | join_graph)
456
+ super_graph.add_node("supervisor", supervisor_node)
457
+
458
+ super_graph.add_edge("ResearchTeam", "supervisor")
459
+ super_graph.add_edge("PaperWritingTeam", "supervisor")
460
+ super_graph.add_conditional_edges(
461
+ "supervisor",
462
+ lambda x: x["next"],
463
+ {
464
+ "PaperWritingTeam": "PaperWritingTeam",
465
+ "ResearchTeam": "ResearchTeam",
466
+ "FINISH": END,
467
+ },
468
+ )
469
+ super_graph.add_edge(START, "supervisor")
470
+ super_graph = super_graph.compile()
471
+
472
+ input_text = st.text_input("Enter your query:")
473
+
474
+ if input_text and os.getenv("OPENAI_API_KEY") and os.getenv("TAVILY_API_KEY"):
475
+ st.markdown("### 🛠️ Task Progress")
476
+ start_time = time.time()
477
+ max_execution_time = 300 # 5 minutes
478
+
479
+ try:
480
+ for s in super_graph.stream(
481
+ {
482
+ "messages": [
483
+ HumanMessage(
484
+ content=input_text
485
+ )
486
+ ],
487
+ },
488
+ {"recursion_limit": 300}, # Increased recursion limit
489
+ ):
490
+ if "__end__" not in s:
491
+ st.write(s)
492
+ st.write("---")
493
+
494
+ # Check for timeout
495
+ if time.time() - start_time > max_execution_time:
496
+ st.warning("Execution time exceeded. Terminating the process.")
497
+ break
498
+ except Exception as e:
499
+ st.error(f"An error occurred: {str(e)}")
500
+ logger.error(f"Error in super_graph execution: {str(e)}")
501
+
502
+ if st.button("List Output Files"):
503
+ files = os.listdir(WORKING_DIRECTORY)
504
+ if files:
505
+ st.write("### 📂 Files in working directory:")
506
+ for file in files:
507
+ st.write(f"📄 {file}")
508
+ else:
509
+ st.write("No files found in the working directory.")
510
+
511
# Offer only regular files for download; a missing folder yields no entries
# instead of raising.
output_files = (
    sorted(entry.name for entry in WORKING_DIRECTORY.iterdir() if entry.is_file())
    if WORKING_DIRECTORY.exists()
    else []
)
if output_files:
    output_file = st.selectbox("Select an output file to download:", output_files)
    file_path = WORKING_DIRECTORY / output_file
    if file_path.exists():
        # Render the download button directly. Nested inside st.button it was
        # only drawn on the single rerun triggered by that click.
        st.download_button(
            label="📥 Download Output Document",
            data=file_path.read_bytes(),
            file_name=output_file,
        )
    else:
        st.write("Output document not found.")
else:
    st.write("No output files available for download.")
528
+
529
+ # Cleanup
530
+ if st.button("Clear Working Directory"):
531
+ for file in WORKING_DIRECTORY.iterdir():
532
+ if file.is_file():
533
+ file.unlink()
534
+ st.success("Working directory cleared.")