swayam-the-coder committed on
Commit
3d68068
·
verified ·
1 Parent(s): a41cc21

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -532
app.py DELETED
@@ -1,532 +0,0 @@
1
- import os
2
- import streamlit as st
3
- from pathlib import Path
4
- from tempfile import TemporaryDirectory
5
- from langchain_core.messages import BaseMessage, HumanMessage
6
- from typing import Annotated, List, Optional, Dict
7
- from typing_extensions import TypedDict
8
- from langchain_community.document_loaders import WebBaseLoader
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_core.tools import tool
11
- from langchain.agents import AgentExecutor, create_openai_functions_agent
12
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
13
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
- from langchain_openai import ChatOpenAI
15
- from langgraph.graph import END, StateGraph, START
16
- import functools
17
- import operator
18
- import logging
19
- import time
20
- from tenacity import retry, stop_after_attempt, wait_exponential
21
-
22
# Set up logging
# Module-level logger: every tool and agent below reports failures through it,
# so errors surface in the server console even when the UI swallows them.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
25
-
26
# Streamlit UI
# Must be the first Streamlit call in the script.
st.set_page_config(page_title="MARS: Multi-Agent Report Synthesizer", layout="wide")

# Custom CSS for styling (injected raw; unsafe_allow_html is required for <style>).
st.markdown("""
<style>
body {
    background-color: #f5f5f5;
    color: #333333;
    font-family: 'Comic Sans MS', 'Comic Sans', cursive;
}
.report-container {
    border-radius: 10px;
    background-color: #ffcccb;
    padding: 20px;
}
.sidebar .sidebar-content {
    background-color: #333333;
    color: #ffffff;
}
.stButton button {
    background-color: #ff6347;
    color: #ffffff;
    border-radius: 5px;
    font-size: 18px;
    padding: 10px 20px;
    font-weight: bold;
}
.stTextInput input {
    border-radius: 5px;
    border: 2px solid #ff6347;
    font-size: 16px;
    padding: 10px;
    width: 100%;
}
.stTextInput label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox label, .stDownloadButton label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox div, .stDownloadButton div {
    background-color: #ffcccb;
    color: #333333;
    border-radius: 5px;
    padding: 10px;
    font-size: 16px;
}
</style>
""", unsafe_allow_html=True)
80
-
81
st.title("🚀 MARS: Multi-agent Report Synthesizer 🤖")

# Sidebar for API key inputs
st.sidebar.title("API Keys")
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
tavily_api_key = st.sidebar.text_input("Tavily API Key", type="password")

# Use user-input API keys or fall back to environment variables.
# Bug fix: the original assigned the fallback expression directly into
# os.environ; when neither the sidebar nor the environment supplied a key the
# value was None, and `os.environ[k] = None` raises TypeError on startup.
# Only write the variable when a key is actually available.
_openai_key = openai_api_key or os.getenv("OPENAI_API_KEY")
_tavily_key = tavily_api_key or os.getenv("TAVILY_API_KEY")
if _openai_key:
    os.environ["OPENAI_API_KEY"] = _openai_key
if _tavily_key:
    os.environ["TAVILY_API_KEY"] = _tavily_key

st.sidebar.title("📋 Instructions")
st.sidebar.write("""
1. Enter your OpenAI and Tavily API keys in the fields above.
2. Enter your query in the input box.
3. Marvin AI will assign tasks to different teams.
4. You can see the progress and download the final report.
5. Use the buttons to list and download output files.
""")

# Set environment variables
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Initialize temporary directory once per session. Keep the TemporaryDirectory
# object itself in session_state: Streamlit re-executes the script on every
# interaction, and if the object's only reference were a module global from
# the first run, its finalizer could delete the directory while still in use.
if 'working_directory' not in st.session_state:
    _TEMP_DIRECTORY = TemporaryDirectory()
    st.session_state.temp_directory_handle = _TEMP_DIRECTORY
    st.session_state.working_directory = Path(_TEMP_DIRECTORY.name)

WORKING_DIRECTORY = st.session_state.working_directory
110
-
111
# Define tools
# Retry with exponential backoff: up to 3 attempts, waiting 4-10 seconds.
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def tavily_search_with_retry(*args, **kwargs):
    # NOTE(review): this retries *construction* of the TavilySearchResults
    # tool, not the search requests it later performs — presumably the intent
    # was to retry searches; confirm whether construction can fail transiently.
    return TavilySearchResults(*args, **kwargs)

# Shared Tavily search tool. At import time the sidebar value may still be
# empty, in which case the TAVILY_API_KEY environment variable (if set) is used.
tavily_tool = tavily_search_with_retry(max_results=5, api_key=tavily_api_key or os.getenv("TAVILY_API_KEY"))
117
-
118
@tool
def scrape_webpages(urls: List[str]) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    try:
        documents = WebBaseLoader(urls).load()
        # One framed page body per document, separated by blank lines.
        pages = [f'\n{doc.page_content}\n' for doc in documents]
        return "\n\n".join(pages)
    except Exception as e:
        logger.error(f"Error in scrape_webpages: {str(e)}")
        return f"Error occurred while scraping webpages: {str(e)}"
133
-
134
@tool
def create_outline(
    points: Annotated[List[str], "List of main points or sections."],
    file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
    """Create and save an outline."""
    try:
        # Number each point 1..N, one per line.
        numbered = [f"{idx + 1}. {point}\n" for idx, point in enumerate(points)]
        with (WORKING_DIRECTORY / file_name).open("w") as handle:
            handle.writelines(numbered)
        return f"Outline saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in create_outline: {str(e)}")
        return f"Error occurred while creating outline: {str(e)}"
148
-
149
@tool
def read_document(
    file_name: Annotated[str, "File path to save the document."],
    start: Annotated[Optional[int], "The start line. Default is 0"] = None,
    end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
    """Read the specified document.

    Returns the lines in the half-open range [start, end) joined by newlines,
    or an error-message string on failure.
    """
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Bug fix: the original used `if start is not None: start = 0`, which
        # silently discarded any caller-supplied start line and never applied
        # the documented default. Default start to 0 only when it is omitted.
        if start is None:
            start = 0
        # NOTE(review): readlines() keeps trailing '\n', so joining with '\n'
        # double-spaces the output; kept as-is to preserve the tool's
        # established output format — confirm before changing.
        return "\n".join(lines[start:end])
    except Exception as e:
        logger.error(f"Error in read_document: {str(e)}")
        return f"Error occurred while reading document: {str(e)}"
165
-
166
@tool
def write_document(
    content: Annotated[str, "Text content to be written into the document."],
    file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
    """Create and save a text document."""
    try:
        # Overwrites any existing file of the same name in the working directory.
        target = WORKING_DIRECTORY / file_name
        with target.open("w") as handle:
            handle.write(content)
        return f"Document saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in write_document: {str(e)}")
        return f"Error occurred while writing document: {str(e)}"
179
-
180
@tool
def edit_document(
    file_name: Annotated[str, "Path of the document to be edited."],
    inserts: Annotated[
        Dict[int, str],
        "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
    ],
) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers."""
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Apply inserts in ascending key order. Each insert shifts subsequent
        # lines down, so later keys address positions in the already-modified
        # document, not the original file.
        sorted_inserts = sorted(inserts.items())
        for line_number, text in sorted_inserts:
            if 1 <= line_number <= len(lines) + 1:
                lines.insert(line_number - 1, text + "\n")
            else:
                # Abort before writing: an out-of-range key leaves the file untouched.
                return f"Error: Line number {line_number} is out of range."
        with (WORKING_DIRECTORY / file_name).open("w") as file:
            file.writelines(lines)
        return f"Document edited and saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in edit_document: {str(e)}")
        return f"Error occurred while editing document: {str(e)}"
204
-
205
# Define the agents and their tools
# Single shared chat model used by every agent and supervisor in the graphs.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", openai_api_key=openai_api_key or os.getenv("OPENAI_API_KEY"))
207
-
208
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> AgentExecutor:
    """Create a function-calling agent and add it to the graph.

    Args:
        llm: Chat model the agent reasons with.
        tools: Tools the agent is allowed to call.
        system_prompt: Role-specific instructions; a collaboration footer
            referencing {team_members} is appended.

    Returns:
        An AgentExecutor wrapping the agent. (Bug fix: the original return
        annotation was ``str``, which did not match the actual return value.)
    """
    system_prompt += """\nWork autonomously according to your specialty, using the tools available to you.
Do not ask for clarification.
Your other team members (and other teams) will collaborate with you with their own specialties.
You are chosen for a reason! You are one of the following team members: {team_members}."""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_functions_agent(llm, tools, prompt)
    executor = AgentExecutor(agent=agent, tools=tools)
    return executor
224
-
225
def agent_node(state, agent, name):
    """Run *agent* on the current graph state and wrap its output as a message.

    On failure, the exception text is returned as the message instead of
    propagating, so the graph keeps running.
    """
    try:
        logger.info(f"Starting {name} agent")
        result = agent.invoke(state)
        logger.info(f"{name} agent completed with result: {result}")
        reply = HumanMessage(content=result["output"], name=name)
    except Exception as e:
        logger.error(f"Error in {name} agent: {str(e)}")
        reply = HumanMessage(content=f"Error occurred in {name} agent: {str(e)}", name=name)
    return {"messages": [reply]}
234
-
235
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members):
    """An LLM-based router.

    Builds a chain that asks *llm* which member should act next by forcing a
    call to a "route" function whose arguments are parsed into a dict like
    ``{"next": <member name or "FINISH">}``.

    (Bug fix: the original return annotation was ``str``; the chain actually
    yields a dict via JsonOutputFunctionsParser, so the wrong annotation was
    removed.)
    """
    options = ["FINISH"] + members
    # JSON-schema function definition the model must call; "next" is
    # constrained to the member names plus FINISH.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    system_prompt += "\nEnsure that you direct the workflow to completion. If no progress is being made, or if the task seems complete, choose FINISH."
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    return (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
268
-
269
# ResearchTeam graph state
class ResearchTeamState(TypedDict):
    # Conversation history; operator.add makes LangGraph append new messages
    # instead of replacing the list.
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of the agents on this team (fills {team_members} in prompts).
    team_members: List[str]
    # Name of the next node chosen by the supervisor ("Search", "WebScraper", or "FINISH").
    next: str
274
-
275
# Worker 1: Tavily web search.
search_agent = create_agent(
    llm,
    [tavily_tool],
    "You are a research assistant who can search for up-to-date info using the tavily search engine.",
)
search_node = functools.partial(agent_node, agent=search_agent, name="Search")

# Worker 2: page scraping for URLs found by the searcher.
research_agent = create_agent(
    llm,
    [scrape_webpages],
    "You are a research assistant who can scrape specified urls for more detailed information using the scrape_webpages function.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="WebScraper")

# Router that picks the next worker or FINISH.
supervisor_agent = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: Search, WebScraper. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["Search", "WebScraper"],
)

research_graph = StateGraph(ResearchTeamState)
research_graph.add_node("Search", search_node)
research_graph.add_node("WebScraper", research_node)
research_graph.add_node("supervisor", supervisor_agent)

# Define the control flow: workers always report back to the supervisor,
# which routes onward based on its "next" decision.
research_graph.add_edge("Search", "supervisor")
research_graph.add_edge("WebScraper", "supervisor")
research_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {"Search": "Search", "WebScraper": "WebScraper", "FINISH": END},
)

research_graph.add_edge(START, "supervisor")
chain = research_graph.compile()

def enter_chain(message: str):
    # Wrap a plain user string into the graph's state shape.
    results = {
        "messages": [HumanMessage(content=message)],
    }
    return results

# Piping a plain function into a compiled graph coerces it into the chain,
# so research_chain accepts a raw string as input.
research_chain = enter_chain | chain
323
-
324
# Document writing team graph state
class DocWritingState(TypedDict):
    # Conversation history; operator.add appends rather than replaces.
    messages: Annotated[List[BaseMessage], operator.add]
    # Comma-joined member names (set by enter_chain below).
    team_members: str
    # Next node chosen by this team's supervisor.
    next: str
    # Human-readable listing of files already in WORKING_DIRECTORY (set by prelude).
    current_files: str
330
-
331
def prelude(state):
    """Augment *state* with a listing of files already in WORKING_DIRECTORY.

    Ensures the working directory exists, then sets "current_files" to either
    a placeholder or a bulleted list of relative paths found recursively.
    """
    if not WORKING_DIRECTORY.exists():
        WORKING_DIRECTORY.mkdir()
    try:
        written_files = [
            path.relative_to(WORKING_DIRECTORY) for path in WORKING_DIRECTORY.rglob("*")
        ]
    except Exception:
        # Best-effort: a failed scan is treated the same as an empty directory.
        written_files = []
    if not written_files:
        return {**state, "current_files": "No files written."}
    listing = "\n".join([f" - {f}" for f in written_files])
    return {
        **state,
        "current_files": "\nBelow are files your team has written to the directory:\n" + listing,
    }
348
-
349
# Writer: creates and edits the actual document files.
doc_writer_agent = create_agent(
    llm,
    [write_document, edit_document, read_document],
    "You are an expert writing a research document.\n"
    "Below are files currently in your directory:\n{current_files}",
)
# prelude runs first so the agent's prompt lists the files written so far.
context_aware_doc_writer_agent = prelude | doc_writer_agent
doc_writing_node = functools.partial(
    agent_node, agent=context_aware_doc_writer_agent, name="DocWriter"
)

# Note taker: produces outlines and reads existing documents.
note_taking_agent = create_agent(
    llm,
    [create_outline, read_document],
    "You are an expert senior researcher tasked with writing a paper outline and"
    " taking notes to craft a perfect paper.{current_files}",
)
context_aware_note_taking_agent = prelude | note_taking_agent
note_taking_node = functools.partial(
    agent_node, agent=context_aware_note_taking_agent, name="NoteTaker"
)

# Chart generator: read-only access to the team's files.
chart_generating_agent = create_agent(
    llm,
    [read_document],
    "You are a data viz expert tasked with generating charts for a research project."
    "{current_files}",
)
context_aware_chart_generating_agent = prelude | chart_generating_agent
chart_generating_node = functools.partial(
    agent_node, agent=context_aware_chart_generating_agent, name="ChartGenerator"
)

doc_writing_supervisor = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["DocWriter", "NoteTaker", "ChartGenerator"],
)

authoring_graph = StateGraph(DocWritingState)
authoring_graph.add_node("DocWriter", doc_writing_node)
authoring_graph.add_node("NoteTaker", note_taking_node)
authoring_graph.add_node("ChartGenerator", chart_generating_node)
authoring_graph.add_node("supervisor", doc_writing_supervisor)

# Workers always report back to the supervisor, which routes onward.
authoring_graph.add_edge("DocWriter", "supervisor")
authoring_graph.add_edge("NoteTaker", "supervisor")
authoring_graph.add_edge("ChartGenerator", "supervisor")
authoring_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "DocWriter": "DocWriter",
        "NoteTaker": "NoteTaker",
        "ChartGenerator": "ChartGenerator",
        "FINISH": END,
    },
)

authoring_graph.add_edge(START, "supervisor")
# NOTE(review): this rebinds the module-level names `chain` and (below)
# `enter_chain` that the research team also used; research_chain was already
# built above so it is unaffected, but the shadowing is fragile.
chain = authoring_graph.compile()

def enter_chain(message: str, members: List[str]):
    # Wrap a plain user string into the authoring graph's state shape,
    # including the team-member list this team's prompts reference.
    results = {
        "messages": [HumanMessage(content=message)],
        "team_members": ", ".join(members),
    }
    return results

authoring_chain = (
    functools.partial(enter_chain, members=authoring_graph.nodes)
    | authoring_graph.compile()
)
426
-
427
# Top-level router: picks which team runs next, or FINISH.
supervisor_node = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following teams: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. Make sure each team is used atleast once. When finished,"
    " respond with FINISH.",
    ["ResearchTeam", "PaperWritingTeam"],
)
436
-
437
class State(TypedDict):
    # Accumulated top-level conversation; operator.add appends new messages.
    messages: Annotated[List[BaseMessage], operator.add]
    # Next node chosen by the top-level supervisor.
    next: str
440
-
441
def get_last_message(state: State) -> str:
    """Return the text content of the most recent message in *state*."""
    messages = state["messages"]
    return messages[-1].content
443
-
444
def join_graph(response: dict):
    """Collapse a sub-graph's response to just its final message."""
    last_message = response["messages"][-1]
    return {"messages": [last_message]}
446
-
447
super_graph = StateGraph(State)
# Each team node: extract the latest message text, run that team's chain,
# then keep only the team's final message in the top-level state.
super_graph.add_node("ResearchTeam", get_last_message | research_chain | join_graph)
super_graph.add_node("PaperWritingTeam", get_last_message | authoring_chain | join_graph)
super_graph.add_node("supervisor", supervisor_node)

# Teams always report back to the supervisor, which routes onward.
super_graph.add_edge("ResearchTeam", "supervisor")
super_graph.add_edge("PaperWritingTeam", "supervisor")
super_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "PaperWritingTeam": "PaperWritingTeam",
        "ResearchTeam": "ResearchTeam",
        "FINISH": END,
    },
)
super_graph.add_edge(START, "supervisor")
super_graph = super_graph.compile()
465
-
466
# Streamlit UI
input_text = st.text_input("Enter your query:")

if input_text:
    if not (openai_api_key and tavily_api_key):
        st.error("Please provide both OpenAI and Tavily API keys in the sidebar.")
    else:
        st.markdown("### 🛠️ Task Progress")
        start_time = time.time()
        max_execution_time = 300  # 5 minutes

        try:
            # Stream intermediate graph states so the user sees each step live.
            for s in super_graph.stream(
                {
                    "messages": [
                        HumanMessage(
                            content=input_text
                        )
                    ],
                },
                {"recursion_limit": 300},  # Increased recursion limit
            ):
                if "__end__" not in s:
                    st.write(s)
                    st.write("---")

                # Check for timeout
                # NOTE(review): checked only after each step finishes, so one
                # long-running step can overshoot the limit before we break.
                if time.time() - start_time > max_execution_time:
                    st.warning("Execution time exceeded. Terminating the process.")
                    break
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            logger.error(f"Error in super_graph execution: {str(e)}")

if st.button("List Output Files"):
    files = os.listdir(WORKING_DIRECTORY)
    if files:
        st.write("### 📂 Files in working directory:")
        for file in files:
            st.write(f"📄 {file}")
    else:
        st.write("No files found in the working directory.")

output_files = os.listdir(WORKING_DIRECTORY)
if output_files:
    output_file = st.selectbox("Select an output file to download:", output_files)

    # NOTE(review): st.download_button nested inside st.button only renders
    # for the single rerun right after the outer button is clicked; clicking
    # the download control triggers a rerun that hides it again. Consider
    # rendering the download_button unconditionally with pre-read bytes.
    if st.button("Download Output Document"):
        file_path = WORKING_DIRECTORY / output_file
        if file_path.exists():
            with file_path.open("rb") as file:
                st.download_button(
                    label="📥 Download Output Document",
                    data=file,
                    file_name=output_file,
                )
        else:
            st.write("Output document not found.")
else:
    st.write("No output files available for download.")

# Cleanup
if st.button("Clear Working Directory"):
    # Removes regular files only; subdirectories are left in place.
    for file in WORKING_DIRECTORY.iterdir():
        if file.is_file():
            file.unlink()
    st.success("Working directory cleared.")