Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| from pathlib import Path | |
| from tempfile import TemporaryDirectory | |
| from langchain_core.messages import BaseMessage, HumanMessage | |
| from typing import Annotated, List, Optional, Dict | |
| from typing_extensions import TypedDict | |
| from langchain_community.document_loaders import WebBaseLoader | |
| from langchain_community.tools.tavily_search import TavilySearchResults | |
| from langchain_core.tools import tool | |
| from langchain.agents import AgentExecutor, create_openai_functions_agent | |
| from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.graph import END, StateGraph, START | |
| import functools | |
| import operator | |
| import logging | |
| import time | |
| from tenacity import retry, stop_after_attempt, wait_exponential, RetryError | |
| from pydantic import ValidationError | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the per-session working directory.
# The TemporaryDirectory object itself is kept in session state, not just its
# path: if the object only lived in a script-level variable it would be
# finalized once the script run is discarded, and its cleanup would delete
# the directory while session state still pointed at the (now missing) path.
if 'working_directory' not in st.session_state:
    st.session_state.temp_directory = TemporaryDirectory()
    st.session_state.working_directory = Path(st.session_state.temp_directory.name)
WORKING_DIRECTORY = st.session_state.working_directory
# Recreate the directory if it disappeared, so later os.listdir()/iterdir()
# calls on it cannot fail with FileNotFoundError.
WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
# Streamlit UI
_PAGE_TITLE = "MARS: Multi-Agent Report Synthesizer"

# Custom CSS for styling, injected once as raw HTML below
_CUSTOM_CSS = """
<style>
body {
background-color: #f5f5f5;
color: #333333;
font-family: 'Comic Sans MS', 'Comic Sans', cursive;
}
.report-container {
border-radius: 10px;
background-color: #ffcccb;
padding: 20px;
}
.sidebar .sidebar-content {
background-color: #333333;
color: #ffffff;
}
.stButton button {
background-color: #ff6347;
color: #ffffff;
border-radius: 5px;
font-size: 18px;
padding: 10px 20px;
font-weight: bold;
}
.stTextInput input {
border-radius: 5px;
border: 2px solid #ff6347;
font-size: 16px;
padding: 10px;
width: 100%;
}
.stTextInput label {
font-size: 18px;
font-weight: bold;
color: #333333;
}
.stSelectbox label, .stDownloadButton label {
font-size: 18px;
font-weight: bold;
color: #333333;
}
.stSelectbox div, .stDownloadButton div {
background-color: #ffcccb;
color: #333333;
border-radius: 5px;
padding: 10px;
font-size: 16px;
}
</style>
"""

# Page configuration is issued before any other element is rendered
st.set_page_config(page_title=_PAGE_TITLE, layout="wide")
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Main page heading
st.title("π MARS: Multi-agent Report Synthesizer π€")

# Sidebar: short usage instructions
_sidebar = st.sidebar
_sidebar.title("π Instructions")
_sidebar.write("""
1. Enter your query in the input box.
2. Marvin AI will assign tasks to different teams.
3. You can see the progress and download the final report.
4. Use the buttons to list and download output files.
""")
# Input fields for API keys
openai_api_key = st.sidebar.text_input("OpenAI API Key", type="password")
tavily_api_key = st.sidebar.text_input("Tavily API Key", type="password")

# Export the entered keys as environment variables (not session state); the
# checks below, and the query section further down, read them via os.getenv.
# NOTE(review): os.environ is process-wide, so a key entered in one browser
# session is visible to every session served by this process -- confirm that
# this is acceptable for the deployment.
if openai_api_key:
    os.environ["OPENAI_API_KEY"] = openai_api_key
if tavily_api_key:
    os.environ["TAVILY_API_KEY"] = tavily_api_key

# Report every missing key, then halt this script run: the code that follows
# constructs the search tool and chat models at module level, which
# presumably cannot succeed without credentials.
_missing_keys = [
    label
    for env_name, label in (("OPENAI_API_KEY", "OpenAI"), ("TAVILY_API_KEY", "Tavily"))
    if not os.getenv(env_name)
]
for _label in _missing_keys:
    st.error(f"{_label} API Key is required.")
if _missing_keys:
    st.stop()
# Define tools
def tavily_search_with_retry(*args, **kwargs):
    """Build a ``TavilySearchResults`` tool, logging any construction failure.

    All positional and keyword arguments are forwarded unchanged. Despite the
    name, nothing is retried here: a failure is logged and then re-raised to
    the caller.
    """
    try:
        return TavilySearchResults(*args, **kwargs)
    except ValidationError as ve:
        logger.error(f"Validation error: {ve}")
        raise
    except Exception as e:
        logger.error(f"Error in Tavily search: {e}")
        raise


# Search tool shared by the research team; at most five results per query
tavily_tool = tavily_search_with_retry(max_results=5)
def scrape_webpages(urls: List[str]) -> str:
    """Use requests and bs4 to scrape the provided web pages for detailed information."""
    # Loads every URL through WebBaseLoader and returns the page texts joined
    # by blank lines. Any failure is logged and reported back as a plain
    # string instead of being raised.
    try:
        documents = WebBaseLoader(urls).load()
        page_texts = []
        for document in documents:
            page_texts.append(f'\n{document.page_content}\n')
        return "\n\n".join(page_texts)
    except Exception as e:
        logger.error(f"Error in scrape_webpages: {str(e)}")
        return f"Error occurred while scraping webpages: {str(e)}"
def create_outline(
    points: Annotated[List[str], "List of main points or sections."],
    file_name: Annotated[str, "File path to save the outline."],
) -> Annotated[str, "Path of the saved outline file."]:
    """Create and save an outline."""
    # Writes one "<n>. <point>" line per entry (numbered from 1) to
    # WORKING_DIRECTORY / file_name. Failures are logged and returned as text.
    try:
        outline_path = WORKING_DIRECTORY / file_name
        with outline_path.open("w") as outline:
            outline.writelines(
                f"{number}. {point}\n" for number, point in enumerate(points, start=1)
            )
    except Exception as e:
        logger.error(f"Error in create_outline: {str(e)}")
        return f"Error occurred while creating outline: {str(e)}"
    else:
        return f"Outline saved to {file_name}"
def read_document(
    file_name: Annotated[str, "File path to read the document from."],
    start: Annotated[Optional[int], "The start line. Default is 0"] = None,
    end: Annotated[Optional[int], "The end line. Default is None"] = None,
) -> str:
    """Read the specified document."""
    # Returns lines [start:end] (0-based, end exclusive) of
    # WORKING_DIRECTORY / file_name as a single string.
    # Failures are logged and returned as an error string rather than raised.
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as file:
            lines = file.readlines()
        # Only fall back to line 0 when no start was requested; a caller's
        # explicit start line must be honoured.
        first_line = 0 if start is None else start
        # readlines() keeps each line's trailing newline, so join with ""
        # to reproduce the file text without inserting extra blank lines.
        return "".join(lines[first_line:end])
    except Exception as e:
        logger.error(f"Error in read_document: {str(e)}")
        return f"Error occurred while reading document: {str(e)}"
def write_document(
    content: Annotated[str, "Text content to be written into the document."],
    file_name: Annotated[str, "File path to save the document."],
) -> Annotated[str, "Path of the saved document file."]:
    """Create and save a text document."""
    # Overwrites WORKING_DIRECTORY / file_name with `content`. Failures are
    # logged and returned as an error string rather than raised.
    try:
        target = WORKING_DIRECTORY / file_name
        with target.open("w") as handle:
            handle.write(content)
    except Exception as e:
        logger.error(f"Error in write_document: {str(e)}")
        return f"Error occurred while writing document: {str(e)}"
    else:
        return f"Document saved to {file_name}"
def edit_document(
    file_name: Annotated[str, "Path of the document to be edited."],
    inserts: Annotated[
        Dict[int, str],
        "Dictionary where key is the line number (1-indexed) and value is the text to be inserted at that line.",
    ],
) -> Annotated[str, "Path of the edited document file."]:
    """Edit a document by inserting text at specific line numbers."""
    # Inserts are applied in ascending line-number order against the list as
    # it grows, so an earlier insert shifts the position of later ones.
    # An out-of-range line number aborts before anything is written back.
    try:
        with (WORKING_DIRECTORY / file_name).open("r") as source:
            lines = source.readlines()

        for line_number, text in sorted(inserts.items()):
            if not 1 <= line_number <= len(lines) + 1:
                return f"Error: Line number {line_number} is out of range."
            lines.insert(line_number - 1, text + "\n")

        with (WORKING_DIRECTORY / file_name).open("w") as sink:
            sink.writelines(lines)
        return f"Document edited and saved to {file_name}"
    except Exception as e:
        logger.error(f"Error in edit_document: {str(e)}")
        return f"Error occurred while editing document: {str(e)}"
# Define the agents and their tools
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")


def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str) -> AgentExecutor:
    """Create a function-calling agent and add it to the graph.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call. Entries may be tool objects or plain
            documented callables; plain callables are wrapped with ``tool``.
        system_prompt: Role description. A shared paragraph containing the
            ``{team_members}`` placeholder is appended, so the state passed to
            the agent must supply ``team_members``.

    Returns:
        An ``AgentExecutor`` wrapping the agent and its tools.
    """
    # Function-scope import: the file only imports `tool` from this module.
    from langchain_core.tools import BaseTool

    # The tool helpers in this file are plain functions; convert them so the
    # agent and executor receive tool objects. Existing tools pass through.
    prepared_tools = [
        candidate if isinstance(candidate, BaseTool) else tool(candidate)
        for candidate in tools
    ]

    system_prompt += (
        "\nWork autonomously according to your specialty, using the tools available to you."
        "\nDo not ask for clarification."
        "\nYour other team members (and other teams) will collaborate with you with their own specialties."
        "\nYou are chosen for a reason! You are one of the following team members: {team_members}."
    )
    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_prompt),
            MessagesPlaceholder(variable_name="messages"),
            MessagesPlaceholder(variable_name="agent_scratchpad"),
        ]
    )
    agent = create_openai_functions_agent(llm, prepared_tools, prompt)
    return AgentExecutor(agent=agent, tools=prepared_tools)
def agent_node(state, agent, name):
    """Invoke ``agent`` on ``state`` and wrap the outcome as a message update.

    Returns a dict with a single-element ``messages`` list holding a
    ``HumanMessage`` named ``name``. Its content is the agent's ``output`` on
    success, or a description of the error if anything raised; errors are
    logged and never propagated.
    """

    def _as_update(text):
        # Shape every outcome the same way so graph state handling is uniform
        return {"messages": [HumanMessage(content=text, name=name)]}

    try:
        logger.info(f"Starting {name} agent")
        result = agent.invoke(state)
        logger.info(f"{name} agent completed with result: {result}")
        return _as_update(result["output"])
    except ValidationError as ve:
        logger.error(f"Validation error in {name} agent: {ve}")
        return _as_update(f"Validation error in {name} agent: {ve}")
    except Exception as e:
        logger.error(f"Error in {name} agent: {e}")
        return _as_update(f"Error occurred in {name} agent: {e}")
def create_team_supervisor(llm: ChatOpenAI, system_prompt, members) -> str:
    """An LLM-based router.

    Builds a prompt -> function-bound LLM -> JSON parser pipeline in which the
    model must call a ``route`` function whose ``next`` argument is either
    ``"FINISH"`` or one of ``members``.

    NOTE(review): the ``-> str`` annotation does not match the value returned
    here, which is the composed pipeline rather than a string.
    """
    options = ["FINISH"] + members

    # JSON schema of the single function the model is forced to call
    next_property = {
        "title": "Next",
        "anyOf": [
            {"enum": options},
        ],
    }
    route_parameters = {
        "title": "routeSchema",
        "type": "object",
        "properties": {
            "next": next_property,
        },
        "required": ["next"],
    }
    function_def = {
        "name": "route",
        "description": "Select the next role.",
        "parameters": route_parameters,
    }

    system_prompt += "\nEnsure that you direct the workflow to completion. If no progress is being made, or if the task seems complete, choose FINISH."
    prompt_messages = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        ("system", "Given the conversation above, who should act next? Or should we FINISH? Select one of: {options}"),
    ]
    prompt = ChatPromptTemplate.from_messages(prompt_messages).partial(
        options=str(options), team_members=", ".join(members)
    )

    routing_llm = llm.bind_functions(functions=[function_def], function_call="route")
    return prompt | routing_llm | JsonOutputFunctionsParser()
# ResearchTeam graph state
class ResearchTeamState(TypedDict):
    """State dictionary passed between the nodes of the research team graph."""

    # Conversation history; annotated with operator.add, presumably so that
    # each node's returned messages are appended rather than replacing the list
    messages: Annotated[List[BaseMessage], operator.add]
    # Names of the team's workers
    team_members: List[str]
    # Routing decision read by the graph's conditional edges
    next: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

# Worker that searches the web through the Tavily tool
search_agent = create_agent(
    llm=llm,
    tools=[tavily_tool],
    system_prompt="You are a research assistant who can search for up-to-date info using the tavily search engine.",
)
search_node = functools.partial(agent_node, agent=search_agent, name="Search")

# Worker that scrapes specific URLs for detail
research_agent = create_agent(
    llm=llm,
    tools=[scrape_webpages],
    system_prompt="You are a research assistant who can scrape specified urls for more detailed information using the scrape_webpages function.",
)
research_node = functools.partial(agent_node, agent=research_agent, name="WebScraper")

# Router that decides which research worker acts next
supervisor_agent = create_team_supervisor(
    llm=llm,
    system_prompt=(
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: Search, WebScraper. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    ),
    members=["Search", "WebScraper"],
)
research_graph = StateGraph(ResearchTeamState)
for _node_name, _node in (
    ("Search", search_node),
    ("WebScraper", research_node),
    ("supervisor", supervisor_agent),
):
    research_graph.add_node(_node_name, _node)

# Define the control flow: every worker reports back to the supervisor,
# which then routes on the "next" value it produced.
for _worker in ("Search", "WebScraper"):
    research_graph.add_edge(_worker, "supervisor")
research_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {"Search": "Search", "WebScraper": "WebScraper", "FINISH": END},
)
research_graph.add_edge(START, "supervisor")
chain = research_graph.compile()
def enter_chain(message: str):
    """Build the initial research-team state from a plain text request.

    Args:
        message: The request to hand to the research team.

    Returns:
        A state dict with the request as the only message and the list of
        research workers under ``team_members``.
    """
    return {
        "messages": [HumanMessage(content=message)],
        # The prompt of every agent built by create_agent contains a
        # {team_members} placeholder, so the state must provide a value.
        "team_members": ["Search", "WebScraper"],
    }


research_chain = enter_chain | chain
# Document writing team graph state
class DocWritingState(TypedDict):
    """State dictionary passed between the nodes of the document writing graph."""

    # Conversation history; annotated with operator.add, presumably so that
    # each node's returned messages are appended rather than replacing the list
    messages: Annotated[List[BaseMessage], operator.add]
    # Team member names as one string
    team_members: str
    # Routing decision read by the graph's conditional edges
    next: str
    # Text describing the files currently in the working directory
    current_files: str
def prelude(state):
    """Add a ``current_files`` summary of the working directory to ``state``.

    Args:
        state: The current graph state (a mapping); it is not modified.

    Returns:
        A new dict with all entries of ``state`` plus ``current_files``:
        either ``"No files written."`` or a bulleted, sorted list of the
        files found under ``WORKING_DIRECTORY`` (paths relative to it).
    """
    written_files = []
    try:
        # exist_ok avoids a race between checking for and creating the directory
        WORKING_DIRECTORY.mkdir(parents=True, exist_ok=True)
        # Only regular files are reported; sorting keeps the prompt stable
        written_files = sorted(
            f.relative_to(WORKING_DIRECTORY)
            for f in WORKING_DIRECTORY.rglob("*")
            if f.is_file()
        )
    except OSError as e:
        # Best effort: fall back to "no files", but leave a trace in the log
        logger.warning(f"Could not list files in working directory: {e}")
    if not written_files:
        return {**state, "current_files": "No files written."}
    return {
        **state,
        "current_files": "\nBelow are files your team has written to the directory:\n"
        + "\n".join([f" - {f}" for f in written_files]),
    }
# Document writer: may write, edit and read files in the working directory
doc_writer_agent = create_agent(
    llm=llm,
    tools=[write_document, edit_document, read_document],
    system_prompt=(
        "You are an expert writing a research document.\n"
        "Below are files currently in your directory:\n{current_files}"
    ),
)
# prelude fills in {current_files} before the agent runs
context_aware_doc_writer_agent = prelude | doc_writer_agent
doc_writing_node = functools.partial(
    agent_node,
    agent=context_aware_doc_writer_agent,
    name="DocWriter",
)
# Note taker: may create outlines and read files in the working directory
note_taking_agent = create_agent(
    llm=llm,
    tools=[create_outline, read_document],
    system_prompt=(
        "You are an expert senior researcher tasked with writing a paper outline and"
        " taking notes to craft a perfect paper.{current_files}"
    ),
)
# prelude fills in {current_files} before the agent runs
context_aware_note_taking_agent = prelude | note_taking_agent
note_taking_node = functools.partial(
    agent_node,
    agent=context_aware_note_taking_agent,
    name="NoteTaker",
)
# Chart generator: may read files in the working directory
chart_generating_agent = create_agent(
    llm,
    [read_document],
    "You are a data viz expert tasked with generating charts for a research project."
    "{current_files}",
)
# prelude fills in {current_files} before the agent runs
context_aware_chart_generating_agent = prelude | chart_generating_agent
# The ChartGenerator node must run the chart-generating agent defined above,
# not the note-taking agent.
chart_generating_node = functools.partial(
    agent_node, agent=context_aware_chart_generating_agent, name="ChartGenerator"
)
# Router that decides which document-writing worker acts next
doc_writing_supervisor = create_team_supervisor(
    llm=llm,
    system_prompt=(
        "You are a supervisor tasked with managing a conversation between the"
        " following workers: {team_members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. When finished,"
        " respond with FINISH."
    ),
    members=["DocWriter", "NoteTaker", "ChartGenerator"],
)
authoring_graph = StateGraph(DocWritingState)
for _node_name, _node in (
    ("DocWriter", doc_writing_node),
    ("NoteTaker", note_taking_node),
    ("ChartGenerator", chart_generating_node),
    ("supervisor", doc_writing_supervisor),
):
    authoring_graph.add_node(_node_name, _node)

# Every worker reports back to the supervisor, which routes on "next"
for _worker in ("DocWriter", "NoteTaker", "ChartGenerator"):
    authoring_graph.add_edge(_worker, "supervisor")
authoring_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "DocWriter": "DocWriter",
        "NoteTaker": "NoteTaker",
        "ChartGenerator": "ChartGenerator",
        "FINISH": END,
    },
)
authoring_graph.add_edge(START, "supervisor")
chain = authoring_graph.compile()
def enter_chain(message: str, members: List[str]):
    """Build the initial document-writing state from a plain text request.

    Returns a dict holding the request as the only message and ``members``
    joined into one comma-separated string under ``team_members``.
    """
    return {
        "messages": [HumanMessage(content=message)],
        "team_members": ", ".join(members),
    }


# Entry adapter with the member names pre-bound from the graph's node names
_authoring_entry = functools.partial(enter_chain, members=authoring_graph.nodes)
authoring_chain = _authoring_entry | authoring_graph.compile()
llm = ChatOpenAI(model="gpt-3.5-turbo-0125")

# Top-level router that chooses between the two teams
supervisor_node = create_team_supervisor(
    llm=llm,
    system_prompt=(
        "You are a supervisor tasked with managing a conversation between the"
        " following teams: {team_members}. Given the following user request,"
        " respond with the worker to act next. Each worker will perform a"
        " task and respond with their results and status. Make sure each team is used atleast once. When finished,"
        " respond with FINISH."
    ),
    members=["ResearchTeam", "PaperWritingTeam"],
)
class State(TypedDict):
    """State dictionary of the top-level graph that coordinates the teams."""

    # Conversation history; annotated with operator.add, presumably so that
    # each node's returned messages are appended rather than replacing the list
    messages: Annotated[List[BaseMessage], operator.add]
    # Routing decision read by the graph's conditional edges
    next: str
def get_last_message(state: State) -> str:
    """Return the ``content`` of the last entry in ``state["messages"]``."""
    messages = state["messages"]
    latest = messages[-1]
    return latest.content
def join_graph(response: dict):
    """Reduce a team's response to a state update holding only its last message."""
    final_message = response["messages"][-1]
    return {"messages": [final_message]}
super_graph = StateGraph(State)

# Each team node: take the latest message text, run the team's chain on it,
# then keep only the team's final message for the top-level state.
_research_team = get_last_message | research_chain | join_graph
_paper_writing_team = get_last_message | authoring_chain | join_graph
super_graph.add_node("ResearchTeam", _research_team)
super_graph.add_node("PaperWritingTeam", _paper_writing_team)
super_graph.add_node("supervisor", supervisor_node)

# Both teams report back to the supervisor, which routes on "next"
for _team in ("ResearchTeam", "PaperWritingTeam"):
    super_graph.add_edge(_team, "supervisor")
super_graph.add_conditional_edges(
    "supervisor",
    lambda x: x["next"],
    {
        "PaperWritingTeam": "PaperWritingTeam",
        "ResearchTeam": "ResearchTeam",
        "FINISH": END,
    },
)
super_graph.add_edge(START, "supervisor")
super_graph = super_graph.compile()
input_text = st.text_input("Enter your query:")

# Run the graph only when there is a query and both API keys are available
if input_text and os.getenv("OPENAI_API_KEY") and os.getenv("TAVILY_API_KEY"):
    st.markdown("### π οΈ Task Progress")
    max_execution_time = 300  # 5 minutes
    start_time = time.time()
    try:
        initial_state = {"messages": [HumanMessage(content=input_text)]}
        run_config = {"recursion_limit": 300}  # Increased recursion limit
        for s in super_graph.stream(initial_state, run_config):
            # Show every streamed update except the terminal marker
            if "__end__" not in s:
                st.write(s)
                st.write("---")
            # Check for timeout; only evaluated between streamed updates
            elapsed = time.time() - start_time
            if elapsed > max_execution_time:
                st.warning("Execution time exceeded. Terminating the process.")
                break
    except RetryError as re:
        st.error(f"Retry error occurred: {re}")
        logger.error(f"Retry error in super_graph execution: {re}")
    except ValidationError as ve:
        st.error(f"Validation error occurred: {ve}")
        logger.error(f"Validation error in super_graph execution: {ve}")
    except Exception as e:
        st.error(f"An error occurred: {str(e)}")
        logger.error(f"Error in super_graph execution: {str(e)}")
# Show the names of everything currently in the working directory on demand
if st.button("List Output Files"):
    files = os.listdir(WORKING_DIRECTORY)
    if not files:
        st.write("No files found in the working directory.")
    else:
        st.write("### π Files in working directory:")
        for file in files:
            st.write(f"π {file}")
# Offer any entry of the working directory for download
output_files = os.listdir(WORKING_DIRECTORY)
if not output_files:
    st.write("No output files available for download.")
else:
    output_file = st.selectbox("Select an output file to download:", output_files)
    if st.button("Download Output Document"):
        file_path = WORKING_DIRECTORY / output_file
        if not file_path.exists():
            st.write("Output document not found.")
        else:
            # The download button is created while the file is still open
            with file_path.open("rb") as file:
                st.download_button(
                    label="π₯ Download Output Document",
                    data=file,
                    file_name=output_file,
                )
# Cleanup: delete the regular files directly inside the working directory
# (sub-directories are left alone)
if st.button("Clear Working Directory"):
    for entry in WORKING_DIRECTORY.iterdir():
        if not entry.is_file():
            continue
        entry.unlink()
    st.success("Working directory cleared.")