swayam-the-coder committed on
Commit
5e4ac43
·
verified ·
1 Parent(s): c6c55e6

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -538
app.py DELETED
@@ -1,538 +0,0 @@
1
- import os
2
- import streamlit as st
3
- from pathlib import Path
4
- from tempfile import TemporaryDirectory
5
- from langchain_core.messages import BaseMessage, HumanMessage
6
- from typing import Annotated, List, Optional, Dict
7
- from typing_extensions import TypedDict
8
- from langchain_community.document_loaders import WebBaseLoader
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_core.tools import tool
11
- from langchain.agents import AgentExecutor, create_openai_functions_agent
12
- from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
13
- from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
14
- from langchain_openai import ChatOpenAI
15
- from langgraph.graph import END, StateGraph, START
16
- import functools
17
- import operator
18
- import logging
19
- import time
20
- from tenacity import retry, stop_after_attempt, wait_exponential
21
-
22
# Set up logging
logging.basicConfig(level=logging.INFO)
# Module-level logger shared by all tool functions and agent nodes below.
logger = logging.getLogger(__name__)

# Streamlit UI
# set_page_config must be the first Streamlit call in the script.
st.set_page_config(page_title="MARS: Multi-Agent Report Synthesizer", layout="wide")
28
-
29
# Custom CSS for styling
# Injected as raw HTML (hence unsafe_allow_html); restyles the page body,
# buttons, text inputs, select boxes and the download widget.
st.markdown("""
<style>
body {
    background-color: #f5f5f5;
    color: #333333;
    font-family: 'Comic Sans MS', 'Comic Sans', cursive;
}
.report-container {
    border-radius: 10px;
    background-color: #ffcccb;
    padding: 20px;
}
.sidebar .sidebar-content {
    background-color: #333333;
    color: #ffffff;
}
.stButton button {
    background-color: #ff6347;
    color: #ffffff;
    border-radius: 5px;
    font-size: 18px;
    padding: 10px 20px;
    font-weight: bold;
}
.stTextInput input {
    border-radius: 5px;
    border: 2px solid #ff6347;
    font-size: 16px;
    padding: 10px;
    width: 100%;
}
.stTextInput label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox label, .stDownloadButton label {
    font-size: 18px;
    font-weight: bold;
    color: #333333;
}
.stSelectbox div, .stDownloadButton div {
    background-color: #ffcccb;
    color: #333333;
    border-radius: 5px;
    padding: 10px;
    font-size: 16px;
}
</style>
""", unsafe_allow_html=True)
80
-
81
st.title("🚀 MARS: Multi-agent Report Synthesizer 🤖")

# Sidebar for API key inputs
st.sidebar.title("API Keys")
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
tavily_api_key = st.sidebar.text_input("Tavily API Key", type="password")

# Only set environment variables if API keys are provided
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
if tavily_api_key:
    os.environ["TAVILY_API_KEY"] = tavily_api_key

# Use environment variables as fallback
# If the user typed nothing, fall back to keys already set in the environment.
openai_api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
tavily_api_key = tavily_api_key or os.getenv("TAVILY_API_KEY")

st.sidebar.title("📋 Instructions")
st.sidebar.write("""
1. Enter your OpenAI and Tavily API keys in the fields above.
2. Enter your query in the input box.
3. Marvin AI will assign tasks to different teams.
4. You can see the progress and download the final report.
5. Use the buttons to list and download output files.
""")
106
-
107
# Set environment variables
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Initialize temporary directory.
# Bug fix: the TemporaryDirectory object itself must be kept in
# st.session_state. The original bound it to a module global
# (_TEMP_DIRECTORY); Streamlit re-executes the script on every rerun, the
# old global is dropped, and TemporaryDirectory deletes the directory on
# garbage collection — leaving working_directory pointing at a removed path.
if 'working_directory' not in st.session_state:
    st.session_state._temp_directory = TemporaryDirectory()
    st.session_state.working_directory = Path(st.session_state._temp_directory.name)

# Shared scratch directory for all file tools below; survives reruns.
WORKING_DIRECTORY = st.session_state.working_directory
116
-
117
# Define tools
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=10))
def tavily_search_with_retry(*args, **kwargs):
    # Build the Tavily search tool, retrying up to 3 times with exponential
    # backoff (4-10s).
    # NOTE(review): the retry wraps only the TavilySearchResults constructor,
    # not the actual search invocations — presumably the intent was to retry
    # searches; confirm whether construction can fail transiently at all.
    return TavilySearchResults(*args, **kwargs)

# Search tool shared by the research team; returns up to 5 results per query.
tavily_tool = tavily_search_with_retry(max_results=5, api_key=tavily_api_key)
123
-
124
@tool
def scrape_webpages(urls: List[str]) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    # On any failure, return the error text instead of raising so the agent
    # loop keeps running.
    try:
        documents = WebBaseLoader(urls).load()
        pieces = []
        for document in documents:
            pieces.append(f'\n{document.page_content}\n')
        return "\n\n".join(pieces)
    except Exception as e:
        logger.error(f"Error in scrape_webpages: {str(e)}")
        return f"Error occurred while scraping webpages: {str(e)}"
139
-
140
@tool
def create_outline(
    points: Annotated[List[str], "List of main points or sections."],
    file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
    """Create and save an outline."""
    # Writes one numbered line per point ("1. <point>") under WORKING_DIRECTORY.
    try:
        numbered = [f"{idx + 1}. {item}\n" for idx, item in enumerate(points)]
        with (WORKING_DIRECTORY / file_name).open("w") as outline_file:
            outline_file.writelines(numbered)
        return f"Outline saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in create_outline: {str(e)}")
        return f"Error occurred while creating outline: {str(e)}"
154
-
155
@tool
def read_document(
    file_name: Annotated[str, "File path to save the document."],
    start: Annotated[Optional[int], "The start line. Default is 0"] = None,
    end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
    """Read the specified document and return lines[start:end] as one string."""
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Bug fix: the original tested `if start is not None: start = 0`,
        # which reset start to 0 exactly when the caller HAD provided it,
        # making the parameter inert. Default a missing start to 0 instead.
        if start is None:
            start = 0
        # Bug fix: readlines() keeps each line's trailing "\n", so joining
        # with "\n" doubled every line break. Join with "" to reproduce the
        # file content verbatim.
        return "".join(lines[start:end])
    except Exception as e:
        logger.error(f"Error in read_document: {str(e)}")
        return f"Error occurred while reading document: {str(e)}"
171
-
172
@tool
def write_document(
    content: Annotated[str, "Text content to be written into the document."],
    file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
    """Create and save a text document."""
    # Overwrites any existing file of the same name in WORKING_DIRECTORY.
    try:
        target = WORKING_DIRECTORY / file_name
        with target.open("w") as handle:
            handle.write(content)
        return f"Document saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in write_document: {str(e)}")
        return f"Error occurred while writing document: {str(e)}"
185
-
186
@tool
def edit_document(
    file_name: Annotated[str, "Path of the document to be edited."],
    inserts: Annotated[
        Dict[int, str],
        "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
    ],
) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers."""
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Apply insertions in ascending key order.
        # NOTE(review): each insert shifts subsequent lines down, so later
        # keys are interpreted against the already-modified buffer rather
        # than the original file — confirm that is the intended semantics
        # when multiple inserts are supplied.
        sorted_inserts = sorted(inserts.items())
        for line_number, text in sorted_inserts:
            # Valid positions are 1..len+1 (len+1 appends at the end);
            # any out-of-range key aborts without writing anything.
            if 1 <= line_number <= len(lines) + 1:
                lines.insert(line_number - 1, text + "\n")
            else:
                return f"Error: Line number {line_number} is out of range."
        with (WORKING_DIRECTORY / file_name).open("w") as file:
            file.writelines(lines)
        return f"Document edited and saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in edit_document: {str(e)}")
        return f"Error occurred while editing document: {str(e)}"
210
-
211
# Define the agents and their tools
# Single chat model instance shared by every agent and supervisor in the app.
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", api_key=openai_api_key)
213
-
214
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> str:
    """Create a function-calling agent and add it to the graph."""
    # Every agent gets the same autonomy boilerplate appended to its
    # role-specific system prompt; {team_members} is filled in elsewhere.
    autonomy_note = """\nWork autonomously according to your specialty, using the tools available to you.
Do not ask for clarification.
Your other team members (and other teams) will collaborate with you with their own specialties.
You are chosen for a reason! You are one of the following team members: {team_members}."""
    full_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt + autonomy_note),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    # Wrap the OpenAI-functions agent in an executor that drives tool calls.
    return AgentExecutor(
        agent=create_openai_functions_agent(llm, tools, full_prompt),
        tools=tools,
    )
230
-
231
def agent_node(state, agent, name):
    """Run `agent` on `state`; wrap its output (or the error) as a named HumanMessage."""
    try:
        logger.info(f"Starting {name} agent")
        outcome = agent.invoke(state)
        logger.info(f"{name} agent completed with result: {outcome}")
        message = HumanMessage(content=outcome["output"], name=name)
    except Exception as e:
        # Report failures into the conversation instead of crashing the graph.
        logger.error(f"Error in {name} agent: {str(e)}")
        message = HumanMessage(content=f"Error occurred in {name} agent: {str(e)}", name=name)
    return {"messages": [message]}
240
-
241
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> str:
    """An LLM-based router."""
    # The router must answer with exactly one member name, or FINISH.
    options = ["FINISH"] + members
    # OpenAI function-calling schema that constrains the model's reply to
    # {"next": <one of options>}.
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": {
            "title": "routeSchema",
            "type": "object",
            "properties": {
                "next": {
                    "title": "Next",
                    "anyOf": [
                        {"enum": options},
                    ],
                },
            },
            "required": ["next"],
        },
    }
    system_prompt += "\nEnsure that you direct the workflow to completion. If no progress is being made, or if the task seems complete, choose FINISH."
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
        ]
    ).partial(options=str(options), team_members=", ".join(members))
    # Force the model to call `route` and parse its JSON args ({"next": ...}).
    return (
        prompt
        | llm.bind_functions(functions=[function_def], function_call="route")
        | JsonOutputFunctionsParser()
    )
274
-
275
# ResearchTeam graph state
class ResearchTeamState(TypedDict):
    # Conversation so far; operator.add makes LangGraph append new messages
    # rather than replace the list.
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of this team's worker agents.
    team_members: List[str]
    # Node the supervisor routed to next (or "FINISH").
    next: str
280
-
281
# Research team worker 1: Tavily web search.
search_agent = create_agent(
    llm,
    [tavily_tool],
    "You are a research assistant who can search for up-to-date info using the tavily search engine.",
)
search_node = functools.partial(agent_node, agent=search_agent, name="Search")

# Research team worker 2: scrapes specific URLs for detail.
research_agent = create_agent(
    llm,
    [scrape_webpages],
    "You are a research assistant who can scrape specified urls for more detailed information using the scrape_webpages function.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="WebScraper")

# Router that picks which research worker acts next, or FINISH.
supervisor_agent = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: Search, WebScraper. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["Search", "WebScraper"],
)
304
-
305
# Research team graph: both workers report back to the supervisor, which
# either routes to a worker or ends the sub-graph.
research_graph = StateGraph(ResearchTeamState)
research_graph.add_node("Search", search_node)
research_graph.add_node("WebScraper", research_node)
research_graph.add_node("supervisor", supervisor_agent)

# Define the control flow
research_graph.add_edge("Search", "supervisor")
research_graph.add_edge("WebScraper", "supervisor")
research_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {"Search": "Search", "WebScraper": "WebScraper", "FINISH": END},
)

research_graph.add_edge(START, "supervisor")
# NOTE: `chain` is reused as a scratch name later for the authoring graph;
# `research_chain` below captures this compile before that happens.
chain = research_graph.compile()
321
-
322
def enter_chain(message: str):
    # Adapt a plain query string into the research graph's initial state.
    # NOTE(review): this name is redefined further down for the authoring
    # team; safe only because `research_chain` is built immediately below.
    results = {
        "messages": [HumanMessage(content=message)],
    }
    return results

# Full research pipeline: string in -> compiled research graph.
research_chain = enter_chain | chain
329
-
330
# Document writing team graph state
class DocWritingState(TypedDict):
    # Conversation so far; operator.add appends rather than replaces.
    messages: Annotated[List[BaseMessage], operator.add]
    # Comma-separated roster string (see the second enter_chain below).
    team_members: str
    # Node the supervisor routed to next (or "FINISH").
    next: str
    # Human-readable listing of files in WORKING_DIRECTORY, set by prelude.
    current_files: str
336
-
337
def prelude(state):
    """Refresh state["current_files"] with a listing of WORKING_DIRECTORY."""
    if not WORKING_DIRECTORY.exists():
        WORKING_DIRECTORY.mkdir()
    files = []
    try:
        for entry in WORKING_DIRECTORY.rglob("*"):
            files.append(entry.relative_to(WORKING_DIRECTORY))
    except Exception:
        # Best-effort: a failed scan just means an empty listing.
        pass
    if not files:
        return {**state, "current_files": "No files written."}
    listing = "\n".join([f" - {entry}" for entry in files])
    header = "\nBelow are files your team has written to the directory:\n"
    return {**state, "current_files": header + listing}
354
-
355
# Document-writing team workers. Each is piped behind `prelude` so its
# prompt's {current_files} slot reflects the working directory at call time.
doc_writer_agent = create_agent(
    llm,
    [write_document, edit_document, read_document],
    "You are an expert writing a research document.\n"
    "Below are files currently in your directory:\n{current_files}",
)
context_aware_doc_writer_agent = prelude | doc_writer_agent
doc_writing_node = functools.partial(
    agent_node, agent=context_aware_doc_writer_agent, name="DocWriter"
)

note_taking_agent = create_agent(
    llm,
    [create_outline, read_document],
    "You are an expert senior researcher tasked with writing a paper outline and"
    " taking notes to craft a perfect paper.{current_files}",
)
context_aware_note_taking_agent = prelude | note_taking_agent
note_taking_node = functools.partial(
    agent_node, agent=context_aware_note_taking_agent, name="NoteTaker"
)

# NOTE(review): the chart generator only gets read_document — it has no tool
# that actually produces a chart; confirm this is intentional.
chart_generating_agent = create_agent(
    llm,
    [read_document],
    "You are a data viz expert tasked with generating charts for a research project."
    "{current_files}",
)
context_aware_chart_generating_agent = prelude | chart_generating_agent
chart_generating_node = functools.partial(
    agent_node, agent=context_aware_chart_generating_agent, name="ChartGenerator"
)

# Router that picks which writing worker acts next, or FINISH.
doc_writing_supervisor = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following workers: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. When finished,"
    " respond with FINISH.",
    ["DocWriter", "NoteTaker", "ChartGenerator"],
)
397
-
398
# Writing team graph: every worker reports back to the supervisor.
authoring_graph = StateGraph(DocWritingState)
authoring_graph.add_node("DocWriter", doc_writing_node)
authoring_graph.add_node("NoteTaker", note_taking_node)
authoring_graph.add_node("ChartGenerator", chart_generating_node)
authoring_graph.add_node("supervisor", doc_writing_supervisor)

authoring_graph.add_edge("DocWriter", "supervisor")
authoring_graph.add_edge("NoteTaker", "supervisor")
authoring_graph.add_edge("ChartGenerator", "supervisor")
authoring_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "DocWriter": "DocWriter",
        "NoteTaker": "NoteTaker",
        "ChartGenerator": "ChartGenerator",
        "FINISH": END,
    },
)

authoring_graph.add_edge(START, "supervisor")
# NOTE(review): this compile result is never used — `authoring_chain` below
# compiles the graph again, and `chain` previously held the research graph.
# This assignment appears redundant.
chain = authoring_graph.compile()
420
-
421
def enter_chain(message: str, members: List[str]):
    # Entry point for the writing graph: seeds the message list and the
    # comma-separated team roster consumed by {team_members} prompts.
    # NOTE: redefines the research team's enter_chain; safe because
    # `research_chain` was already built by the time this executes.
    results = {
        "messages": [HumanMessage(content=message)],
        "team_members": ", ".join(members),
    }
    return results

# Full writing pipeline; members come from the graph's node names (this
# includes "supervisor" as well as the three workers).
authoring_chain = (
    functools.partial(enter_chain, members=authoring_graph.nodes)
    | authoring_graph.compile()
)
432
-
433
# Meta-supervisor: routes between the two team sub-graphs, or FINISH.
supervisor_node = create_team_supervisor(
    llm,
    "You are a supervisor tasked with managing a conversation between the"
    " following teams: {team_members}. Given the following user request,"
    " respond with the worker to act next. Each worker will perform a"
    " task and respond with their results and status. Make sure each team is used atleast once. When finished,"
    " respond with FINISH.",
    ["ResearchTeam", "PaperWritingTeam"],
)
442
-
443
class State(TypedDict):
    # Top-level graph state: accumulated messages plus the routing decision.
    messages: Annotated[List[BaseMessage], operator.add]
    # Team the meta-supervisor routed to next (or "FINISH").
    next: str
446
-
447
def get_last_message(state: State) -> str:
    """Return the content of the most recent message in the state."""
    messages = state["messages"]
    return messages[-1].content
449
-
450
def join_graph(response: dict):
    """Collapse a sub-graph's response to just its final message."""
    last_message = response["messages"][-1]
    return {"messages": [last_message]}
452
-
453
# Top-level graph: meta-supervisor alternates between the two team chains.
# Each team node unwraps the last message, runs the team, and re-wraps the
# team's final message for the parent state.
super_graph = StateGraph(State)
super_graph.add_node("ResearchTeam", get_last_message | research_chain | join_graph)
super_graph.add_node("PaperWritingTeam", get_last_message | authoring_chain | join_graph)
super_graph.add_node("supervisor", supervisor_node)

super_graph.add_edge("ResearchTeam", "supervisor")
super_graph.add_edge("PaperWritingTeam", "supervisor")
super_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "PaperWritingTeam": "PaperWritingTeam",
        "ResearchTeam": "ResearchTeam",
        "FINISH": END,
    },
)
super_graph.add_edge(START, "supervisor")
# Replace the builder with the compiled, runnable graph.
super_graph = super_graph.compile()
471
-
472
# Streamlit UI
input_text = st.text_input("Enter your query:")

if input_text:
    if not openai_api_key or not tavily_api_key:
        st.error("Please provide both OpenAI and Tavily API keys in the sidebar.")
    else:
        st.markdown("### 🛠️ Task Progress")
        start_time = time.time()
        max_execution_time = 300  # 5 minutes

        try:
            # Stream intermediate graph states so the user sees progress live.
            for s in super_graph.stream(
                {
                    "messages": [
                        HumanMessage(
                            content=input_text
                        )
                    ],
                },
                {"recursion_limit": 300},  # Increased recursion limit
            ):
                if "__end__" not in s:
                    st.write(s)
                    st.write("---")

                # Check for timeout
                # Soft wall-clock limit checked between steps; a single long
                # step can still overshoot it.
                if time.time() - start_time > max_execution_time:
                    st.warning("Execution time exceeded. Terminating the process.")
                    break
        except Exception as e:
            st.error(f"An error occurred: {str(e)}")
            logger.error(f"Error in super_graph execution: {str(e)}")

if st.button("List Output Files"):
    files = os.listdir(WORKING_DIRECTORY)
    if files:
        st.write("### 📂 Files in working directory:")
        for file in files:
            st.write(f"📄 {file}")
    else:
        st.write("No files found in the working directory.")

output_files = os.listdir(WORKING_DIRECTORY)
if output_files:
    output_file = st.selectbox("Select an output file to download:", output_files)

    # NOTE(review): st.download_button is only rendered while the outer
    # "Download Output Document" button's rerun is active — on the next
    # rerun it disappears before it can be clicked. Confirm this flow
    # actually delivers the file; rendering the download button
    # unconditionally is the usual pattern.
    if st.button("Download Output Document"):
        file_path = WORKING_DIRECTORY / output_file
        if file_path.exists():
            with file_path.open("rb") as file:
                st.download_button(
                    label="📥 Download Output Document",
                    data=file,
                    file_name=output_file,
                )
        else:
            st.write("Output document not found.")
else:
    st.write("No output files available for download.")

# Cleanup
if st.button("Clear Working Directory"):
    # Removes files only; subdirectories created by tools are left in place.
    for file in WORKING_DIRECTORY.iterdir():
        if file.is_file():
            file.unlink()
    st.success("Working directory cleared.")