swayam-the-coder committed on
Commit
c6c55e6
·
verified ·
1 Parent(s): 3d68068

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +538 -0
app.py ADDED
@@ -0,0 +1,538 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ from pathlib import Path
4
+ from tempfile import TemporaryDirectory
5
+ from langchain_core.messages import BaseMessage, HumanMessage
6
+ from typing import Annotated, List, Optional, Dict
7
+ from typing_extensions import TypedDict
8
+ from langchain_community.document_loaders import WebBaseLoader
9
+ from langchain_community.tools.tavily_search import TavilySearchResults
10
+ from langchain_core.tools import tool
11
+ from langchain.agents import AgentExecutor, create_openai_functions_agent
12
+ from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
13
+ from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
+ from langchain_openai import ChatOpenAI
15
+ from langgraph.graph import END, StateGraph, START
16
+ import functools
17
+ import operator
18
+ import logging
19
+ import time
20
+ from tenacity import retry, stop_after_attempt, wait_exponential
21
+
22
# Set up module-level logging.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Streamlit page configuration (must run before other st.* calls).
st.set_page_config(page_title="MARS: Multi-Agent Report Synthesizer", layout="wide")

# Custom CSS for styling the app's widgets and report container.
st.markdown("""
<style>
body {
    background-color: #f5f5f5;
    color: #333333;
    font-family: 'Comic Sans MS', 'Comic Sans', cursive;
}
.report-container {
    border-radius: 10px;
    background-color: #ffcccb;
    padding: 20px;
}
.sidebar .sidebar-content {
    background-color: #333333;
    color: #ffffff;
}
.stButton button {
    background-color: #ff6347;
    color: #ffffff;
    border-radius: 5px;
    font-size: 18px;
    padding: 10px 20px;
    font-weight: bold;
}
.stTextInput input {
    border-radius: 5px;
    border: 2px solid #ff6347;
    font-size: 16px;
    padding: 10px;
    width: 100%;
}
.stTextInput label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox label, .stDownloadButton label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox div, .stDownloadButton div {
    background-color: #ffcccb;
    color: #333333;
    border-radius: 5px;
    padding: 10px;
    font-size: 16px;
}
</style>
""", unsafe_allow_html=True)

st.title("🚀 MARS: Multi-agent Report Synthesizer 🤖")

# Sidebar for API key inputs (masked).
st.sidebar.title("API Keys")
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
tavily_api_key = st.sidebar.text_input("Tavily API Key", type="password")

# Only set environment variables if API keys are provided in the sidebar.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
if tavily_api_key:
    os.environ["TAVILY_API_KEY"] = tavily_api_key

# Use pre-existing environment variables as fallback when the sidebar
# fields are left empty.
openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
tavily_api_key = tavily_api_key or os.getenv("TAVILY_API_KEY")

st.sidebar.title("📋 Instructions")
st.sidebar.write("""
1. Enter your OpenAI and Tavily API keys in the fields above.
2. Enter your query in the input box.
3. Marvin AI will assign tasks to different teams.
4. You can see the progress and download the final report.
5. Use the buttons to list and download output files.
""")

# Enable LangChain tracing (LangSmith) for every run.
os.environ["LANGCHAIN_TRACING_V2"] = "true"
109
+
110
# Initialize a per-session temporary working directory for generated files.
# Bug fix: keep the TemporaryDirectory object itself in session_state. The
# original stored only the Path, so the TemporaryDirectory could be
# garbage-collected and its finalizer would delete the directory while the
# app was still writing into it.
if 'working_directory' not in st.session_state:
    _TEMP_DIRECTORY = TemporaryDirectory()
    st.session_state.temp_directory = _TEMP_DIRECTORY
    st.session_state.working_directory = Path(_TEMP_DIRECTORY.name)

WORKING_DIRECTORY = st.session_state.working_directory
116
+
117
# Define tools
# NOTE(review): the retry guards only the *construction* of the Tavily tool,
# not the individual search calls made later — confirm this is the intent.
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def tavily_search_with_retry(*args, **kwargs):
    """Construct a TavilySearchResults tool, retrying transient failures
    up to 3 times with exponential backoff (4s..10s)."""
    return TavilySearchResults(*args, **kwargs)

tavily_tool = tavily_search_with_retry(max_results=5, api_key=tavily_api_key)
123
+
124
@tool
def scrape_webpages(urls: List[str]) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    # NOTE(review): each page is emitted as bare text wrapped in newlines with
    # no header identifying its source URL — confirm whether per-document
    # metadata was intended here.
    try:
        loader = WebBaseLoader(urls)
        docs = loader.load()
        # Concatenate every loaded page, separated by blank lines.
        return "\n\n".join(
            [
                f'\n{doc.page_content}\n'
                for doc in docs
            ]
        )
    except Exception as e:
        # Return the error as text so the calling agent can react to it
        # instead of the whole graph run crashing.
        logger.error(f"Error in scrape_webpages: {str(e)}")
        return f"Error occurred while scraping webpages: {str(e)}"
139
+
140
@tool
def create_outline(
    points: Annotated[List[str], "List of main points or sections."],
    file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
    """Create and save an outline."""
    try:
        # Render the points as a numbered list, one entry per line.
        numbered = "".join(
            f"{index}. {point}\n" for index, point in enumerate(points, start=1)
        )
        target = WORKING_DIRECTORY / file_name
        with target.open("w") as outline_file:
            outline_file.write(numbered)
        return f"Outline saved to {file_name}"
    except Exception as e:
        # Report failures as text so the calling agent can handle them.
        logger.error(f"Error in create_outline: {str(e)}")
        return f"Error occurred while creating outline: {str(e)}"
154
+
155
@tool
def read_document(
    file_name: Annotated[str, "File path of the document to read."],
    start: Annotated[Optional[int], "The start line. Default is 0"] = None,
    end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
    """Read the specified document, optionally slicing it by line range."""
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Bug fix: the original reset `start` to 0 when it was NOT None,
        # discarding any caller-supplied start line. Default only when
        # the caller omitted it.
        if start is None:
            start = 0
        # NOTE(review): readlines() keeps trailing newlines, so joining on
        # "\n" doubles the line spacing — confirm whether that is intended.
        return "\n".join(lines[start:end])
    except Exception as e:
        # Report failures as text so the calling agent can handle them.
        logger.error(f"Error in read_document: {str(e)}")
        return f"Error occurred while reading document: {str(e)}"
171
+
172
@tool
def write_document(
    content: Annotated[str, "Text content to be written into the document."],
    file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
    """Create and save a text document."""
    try:
        target = WORKING_DIRECTORY / file_name
        # Overwrite (or create) the file with the supplied content.
        with target.open("w") as doc_file:
            doc_file.write(content)
        return f"Document saved to {file_name}"
    except Exception as e:
        # Report failures as text so the calling agent can handle them.
        logger.error(f"Error in write_document: {str(e)}")
        return f"Error occurred while writing document: {str(e)}"
185
+
186
@tool
def edit_document(
    file_name: Annotated[str, "Path of the document to be edited."],
    inserts: Annotated[
        Dict[int, str],
        "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
    ],
) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers."""
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Apply inserts in ascending key order. Each insert grows the list,
        # so later keys are interpreted against the already-updated document
        # rather than the original numbering.
        sorted_inserts = sorted(inserts.items())
        for line_number, text in sorted_inserts:
            if 1 <= line_number <= len(lines) + 1:
                lines.insert(line_number - 1, text + "\n")
            else:
                # Abort before writing anything back if any key is out of range.
                return f"Error: Line number {line_number} is out of range."
        # All inserts validated and applied; persist the edited document.
        with (WORKING_DIRECTORY / file_name).open("w") as file:
            file.writelines(lines)
        return f"Document edited and saved to {file_name}"
    except Exception as e:
        # Report failures as text so the calling agent can handle them.
        logger.error(f"Error in edit_document: {str(e)}")
        return f"Error occurred while editing document: {str(e)}"
210
+
211
# Define the agents and their tools
# Single shared chat model used by every agent and supervisor below.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", api_key=openai_api_key)
213
+
214
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> AgentExecutor:
    """Create a function-calling agent executor for use as a graph node.

    The supplied system prompt is extended with boilerplate instructing the
    agent to act autonomously as part of a team.

    Bug fix: the original annotated the return type as `str`, but an
    AgentExecutor is returned.
    """
    # The {team_members} placeholder is presumably filled from the graph
    # state at invoke time — TODO confirm against the callers.
    system_prompt += """\nWork autonomously according to your specialty, using the tools available to you.
Do not ask for clarification.
Your other team members (and other teams) will collaborate with you with their own specialties.
You are chosen for a reason! You are one of the following team members: {team_members}."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor
230
+
231
def agent_node(state, agent, name):
    """Invoke `agent` on the current graph state and wrap its reply.

    Returns a state delta containing a single message attributed to `name`.
    Failures are converted into an error message instead of raised, so the
    surrounding graph keeps running.
    """
    try:
        logger.info(f"Starting {name} agent")
        result = agent.invoke(state)
        logger.info(f"{name} agent completed with result: {result}")
        reply = HumanMessage(content=result["output"], name=name)
    except Exception as e:
        logger.error(f"Error in {name} agent: {str(e)}")
        reply = HumanMessage(content=f"Error occurred in {name} agent: {str(e)}", name=name)
    return {"messages": [reply]}
240
+
241
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members):
    """An LLM-based router.

    Builds a runnable (prompt | llm | parser) that inspects the conversation
    and emits {"next": <member name or "FINISH">} via a forced OpenAI
    function call.

    Bug fix: the original annotated the return type as `str`, but a runnable
    chain is returned; the misleading annotation is removed.
    """
    options = ["FINISH"] + members
    # JSON-schema function definition constraining the model's choice to the
    # known members plus FINISH.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    system_prompt += "\nEnsure that you direct the workflow to completion. If no progress is being made, or if the task seems complete, choose FINISH."
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    # function_call="route" forces the model to answer via the routing
    # function; the parser extracts its JSON arguments.
    return (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
274
+
275
# ResearchTeam graph state
class ResearchTeamState(TypedDict):
    # Conversation so far; operator.add makes node returns append here.
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of the team's worker agents.
    team_members: List[str]
    # Routing decision produced by the supervisor node.
    next: str
280
+
281
# Research team: a web-search agent and a page-scraping agent managed by a
# routing supervisor.
search_agent = create_agent(
    llm,
    [tavily_tool],
    "You are a research assistant who can search for up-to-date info using the tavily search engine.",
)
search_node = functools.partial(agent_node, agent=search_agent, name="Search")

research_agent = create_agent(
    llm,
    [scrape_webpages],
    "You are a research assistant who can scrape specified urls for more detailed information using the scrape_webpages function.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="WebScraper")

supervisor_agent = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: Search, WebScraper. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["Search", "WebScraper"],
)

research_graph = StateGraph(ResearchTeamState)
research_graph.add_node("Search", search_node)
research_graph.add_node("WebScraper", research_node)
research_graph.add_node("supervisor", supervisor_agent)

# Define the control flow: workers always report back to the supervisor,
# which then routes to the next worker or ends the subgraph.
research_graph.add_edge("Search", "supervisor")
research_graph.add_edge("WebScraper", "supervisor")
research_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {"Search": "Search", "WebScraper": "WebScraper", "FINISH": END},
)

research_graph.add_edge(START, "supervisor")
chain = research_graph.compile()

def enter_chain(message: str):
    """Adapt a plain string into the research graph's state shape."""
    results = {
        "messages": [HumanMessage(content=message)],
    }
    return results

# NOTE(review): `enter_chain` is redefined below for the authoring team;
# this composition binds the current definition, so it still works.
research_chain = enter_chain | chain
329
+
330
# Document writing team graph state
class DocWritingState(TypedDict):
    # Conversation so far; operator.add makes node returns append here.
    messages: Annotated[List[BaseMessage], operator.add]
    # Comma-separated roster of the team's worker agents.
    team_members: str
    # Routing decision produced by the supervisor node.
    next: str
    # Human-readable listing of files already in the working directory,
    # refreshed by `prelude` before each agent invocation.
    current_files: str
336
+
337
def prelude(state):
    """Inject a listing of already-written files into the agent state.

    Ensures WORKING_DIRECTORY exists, then exposes its recursive contents
    through the `current_files` key so document-team agents know what has
    been produced so far.
    """
    written_files = []
    if not WORKING_DIRECTORY.exists():
        WORKING_DIRECTORY.mkdir()
    try:
        written_files = [
            f.relative_to(WORKING_DIRECTORY) for f in WORKING_DIRECTORY.rglob("*")
        ]
    except Exception as e:
        # Best-effort listing: keep going with an empty list, but record
        # the failure instead of silently swallowing it (original: pass).
        logger.warning(f"Could not list files in working directory: {str(e)}")
    if not written_files:
        return {**state, "current_files": "No files written."}
    return {
        **state,
        "current_files": "\nBelow are files your team has written to the directory:\n"
        + "\n".join([f" - {f}" for f in written_files]),
    }
354
+
355
# Document-writing team: three agents plus a routing supervisor. Each agent
# is composed with `prelude` so {current_files} is refreshed per invocation.
doc_writer_agent = create_agent(
    llm,
    [write_document, edit_document, read_document],
    "You are an expert writing a research document.\n"
    "Below are files currently in your directory:\n{current_files}",
)
context_aware_doc_writer_agent = prelude | doc_writer_agent
doc_writing_node = functools.partial(
    agent_node, agent=context_aware_doc_writer_agent, name="DocWriter"
)

note_taking_agent = create_agent(
    llm,
    [create_outline, read_document],
    "You are an expert senior researcher tasked with writing a paper outline and"
    " taking notes to craft a perfect paper.{current_files}",
)
context_aware_note_taking_agent = prelude | note_taking_agent
note_taking_node = functools.partial(
    agent_node, agent=context_aware_note_taking_agent, name="NoteTaker"
)

# NOTE(review): the chart generator only has read access (read_document);
# it cannot actually emit chart files — confirm intended capability.
chart_generating_agent = create_agent(
    llm,
    [read_document],
    "You are a data viz expert tasked with generating charts for a research project."
    "{current_files}",
)
context_aware_chart_generating_agent = prelude | chart_generating_agent
chart_generating_node = functools.partial(
    agent_node, agent=context_aware_chart_generating_agent, name="ChartGenerator"
)

doc_writing_supervisor = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["DocWriter", "NoteTaker", "ChartGenerator"],
)
397
+
398
# Authoring (paper-writing) team graph.
authoring_graph = StateGraph(DocWritingState)
authoring_graph.add_node("DocWriter", doc_writing_node)
authoring_graph.add_node("NoteTaker", note_taking_node)
authoring_graph.add_node("ChartGenerator", chart_generating_node)
authoring_graph.add_node("supervisor", doc_writing_supervisor)

# Workers always report back to the supervisor, which routes or finishes.
authoring_graph.add_edge("DocWriter", "supervisor")
authoring_graph.add_edge("NoteTaker", "supervisor")
authoring_graph.add_edge("ChartGenerator", "supervisor")
authoring_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "DocWriter": "DocWriter",
        "NoteTaker": "NoteTaker",
        "ChartGenerator": "ChartGenerator",
        "FINISH": END,
    },
)

authoring_graph.add_edge(START, "supervisor")
# NOTE(review): this compile result shadows the research team's `chain`
# (already captured by research_chain) and is itself unused —
# `authoring_chain` below compiles the graph a second time.
chain = authoring_graph.compile()

def enter_chain(message: str, members: List[str]):
    """Adapt a message plus a team roster into the authoring graph's state.

    NOTE(review): redefines the research team's `enter_chain`; the earlier
    composition already bound the previous definition, so both still work.
    """
    results = {
        "messages": [HumanMessage(content=message)],
        "team_members": ", ".join(members),
    }
    return results

authoring_chain = (
    functools.partial(enter_chain, members=authoring_graph.nodes)
    | authoring_graph.compile()
)
432
+
433
# Top-level supervisor routing between the two team subgraphs.
supervisor_node = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following teams: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. Make sure each team is used atleast once. When finished,"
    " respond with FINISH.",
    ["ResearchTeam", "PaperWritingTeam"],
)
442
+
443
# Top-level (super-graph) state.
class State(TypedDict):
    # Conversation so far; operator.add makes node returns append here.
    messages: Annotated[List[BaseMessage], operator.add]
    # Routing decision produced by the top-level supervisor.
    next: str
446
+
447
def get_last_message(state: State) -> str:
    """Extract the text content of the most recent message in the state."""
    latest = state["messages"][-1]
    return latest.content
449
+
450
def join_graph(response: dict):
    """Reduce a subgraph's response to a delta holding only its final message."""
    final_message = response["messages"][-1]
    return {"messages": [final_message]}
452
+
453
# Super-graph: top-level supervisor orchestrating the two team subgraphs.
super_graph = StateGraph(State)
# Each team node: take the latest message -> run the team chain -> keep only
# the team's final message in the top-level state.
super_graph.add_node("ResearchTeam", get_last_message | research_chain | join_graph)
super_graph.add_node("PaperWritingTeam", get_last_message | authoring_chain | join_graph)
super_graph.add_node("supervisor", supervisor_node)

# Teams always report back to the supervisor for the next routing decision.
super_graph.add_edge("ResearchTeam", "supervisor")
super_graph.add_edge("PaperWritingTeam", "supervisor")
super_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "PaperWritingTeam": "PaperWritingTeam",
        "ResearchTeam": "ResearchTeam",
        "FINISH": END,
    },
)
super_graph.add_edge(START, "supervisor")
super_graph = super_graph.compile()
471
+
472
# Streamlit UI: main query input and report-generation loop.
input_text = st.text_input("Enter your query:")

if input_text:
    if not openai_api_key or not tavily_api_key:
        st.error("Please provide both OpenAI and Tavily API keys in the sidebar.")
    else:
        st.markdown("### 🛠️ Task Progress")
        start_time = time.time()
        max_execution_time = 300  # 5 minutes

        try:
            # Stream intermediate super-graph states so the user can watch
            # each team's progress as it happens.
            for s in super_graph.stream(
                {
                    "messages": [
                        HumanMessage(
                            content=input_text
                        )
                    ],
                },
                {"recursion_limit": 300},  # Increased recursion limit
            ):
                if "__end__" not in s:
                    st.write(s)
                    st.write("---")

                # Check for timeout: abort runs longer than max_execution_time.
                if time.time() - start_time > max_execution_time:
                    st.warning("Execution time exceeded. Terminating the process.")
                    break
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            logger.error(f"Error in super_graph execution: {str(e)}")

# On-demand listing of everything the teams wrote.
if st.button("List Output Files"):
    files = os.listdir(WORKING_DIRECTORY)
    if files:
        st.write("### 📂 Files in working directory:")
        for file in files:
            st.write(f"📄 {file}")
    else:
        st.write("No files found in the working directory.")

# Download UI for generated files.
output_files = os.listdir(WORKING_DIRECTORY)
if output_files:
    output_file = st.selectbox("Select an output file to download:", output_files)

    if st.button("Download Output Document"):
        file_path = WORKING_DIRECTORY / output_file
        if file_path.exists():
            # NOTE(review): a download_button rendered inside another
            # button's click branch only survives a single Streamlit rerun —
            # confirm the download actually works across reruns.
            with file_path.open("rb") as file:
                st.download_button(
                    label="📥 Download Output Document",
                    data=file,
                    file_name=output_file,
                )
        else:
            st.write("Output document not found.")
else:
    st.write("No output files available for download.")

# Cleanup: remove all generated files on demand (top level only; no rglob).
if st.button("Clear Working Directory"):
    for file in WORKING_DIRECTORY.iterdir():
        if file.is_file():
            file.unlink()
    st.success("Working directory cleared.")