| import os |
| import chromadb |
| import streamlit as st |
| from langchain_openai import ChatOpenAI |
| from langchain.agents import AgentExecutor, create_openai_tools_agent |
| from langchain_core.messages import BaseMessage, HumanMessage |
| from langchain_community.tools.tavily_search import TavilySearchResults |
| from langchain_experimental.tools import PythonREPLTool |
| from langchain_community.document_loaders import DirectoryLoader, TextLoader |
| from langchain.text_splitter import RecursiveCharacterTextSplitter |
| from langchain_community.vectorstores import Chroma |
| from langchain.embeddings import HuggingFaceBgeEmbeddings |
| from langchain_core.output_parsers import StrOutputParser |
| from langchain_core.runnables import RunnablePassthrough |
| from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
| from langgraph.graph import StateGraph, END |
| from langchain_core.documents import Document |
| from typing import Annotated, Sequence, TypedDict |
| import functools |
| import operator |
| from langchain_core.tools import tool |
| from glob import glob |
|
|
|
|
| |
# Clear chromadb's shared system-client cache before any vector store is
# created (presumably so each Streamlit rerun starts from a clean client --
# confirm against the chromadb version in use).
chromadb.api.client.SharedSystemClient.clear_system_cache()

# Both credentials come from the process environment.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Halt the script early when either key is missing or empty.
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your environment variables.")
    st.stop()

# Chat model shared by the worker agents, the supervisor and the RAG chain.
llm = ChatOpenAI(
    model="gpt-4-1106-preview",
    openai_api_key=OPENAI_API_KEY,
)
|
|
| |
def create_agent(llm: ChatOpenAI, tools: list, system_prompt: str):
    """Build an ``AgentExecutor`` for a single worker.

    Args:
        llm: Chat model that drives the agent.
        tools: Tools the agent may call; the same list is given to the executor.
        system_prompt: Text of the leading system message.

    Returns:
        An ``AgentExecutor`` wrapping an OpenAI-tools agent.
    """
    # The prompt exposes two slots: the running conversation ("messages") and
    # the agent's own intermediate tool calls ("agent_scratchpad").
    message_layout = [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ]
    worker_prompt = ChatPromptTemplate.from_messages(message_layout)
    worker = create_openai_tools_agent(llm, tools, worker_prompt)
    return AgentExecutor(agent=worker, tools=tools)
|
|
def agent_node(state, agent, name):
    """Run one worker on the graph state and wrap its answer as a message.

    Args:
        state: Current graph state, passed unchanged to ``agent.invoke``.
        agent: Object exposing ``invoke``; its result must contain "output".
        name: Worker name attached to the produced message.

    Returns:
        A dict whose "messages" entry is a one-element list holding the
        worker's output as a ``HumanMessage``.
    """
    answer = agent.invoke(state)["output"]
    reply = HumanMessage(content=answer, name=name)
    return {"messages": [reply]}
|
|
@tool
def RAG(state):
    """Use this tool to execute RAG. If the question is related to Japan or Sports, this tool retrieves the results."""
    # NOTE: the docstring above is used by @tool as the tool description the
    # model sees, so its wording is runtime text -- edit it deliberately.
    # `state` is the question itself; it is passed straight to the chain.

    # Leave a trace in the UI log so the run output shows RAG was invoked.
    st.session_state.outputs.append('-> Calling RAG ->')

    answer_prompt = ChatPromptTemplate.from_template(
        """Answer the question based only on the following context:\n{context}\nQuestion: {question}"""
    )
    # The retriever fills {context}; the raw question passes through as {question}.
    chain_inputs = {"context": retriever, "question": RunnablePassthrough()}
    rag_chain = chain_inputs | answer_prompt | llm | StrOutputParser()
    return rag_chain.invoke(state)
|
|
| |
# Web search tool for the research agent, capped at five results per query.
tavily_tool = TavilySearchResults(
    max_results=5,
    tavily_api_key=TAVILY_API_KEY,
)
# NOTE(review): PythonREPLTool runs model-written Python inside this process.
# Only expose this app to trusted users, or sandbox the interpreter.
python_repl_tool = PythonREPLTool()

st.title("Multi-Agent w Supervisor")

# Sample prompts: the first is the input placeholder, all are listed in the UI.
example_questions = [
    "What is James McIlroy aiming for in sports?",
    "Fetch India's GDP over the past 5 years and draw a line graph.",
    "Fetch Japan's GDP over the past 4 years from RAG, then draw a line graph.",
]
|
|
| |
# Offer every .txt file under ./sources. glob() returns paths in arbitrary
# order, so sort them to keep the two pre-selected defaults stable across runs.
source_files = sorted(glob("sources/*.txt"))
selected_files = st.multiselect("Select files from the source directory:", source_files, default=source_files[:2])

uploaded_files = st.file_uploader("Or upload your TXT files:", accept_multiple_files=True, type=['txt'])

# Documents gathered from both inputs; everything downstream indexes this list.
all_docs = []

# Looping over an empty selection is a no-op, so no explicit guard is needed.
for file_path in selected_files:
    # autodetect_encoding only matters when decoding with the default encoding
    # fails; files that already loaded correctly behave exactly as before.
    all_docs.extend(TextLoader(file_path, autodetect_encoding=True).load())

for uploaded_file in uploaded_files or []:
    try:
        # getvalue() returns the whole buffer regardless of the current read
        # position; "utf-8-sig" also strips a leading BOM if one is present.
        content = uploaded_file.getvalue().decode("utf-8-sig")
    except UnicodeDecodeError:
        # Skip the offending upload instead of crashing the whole app.
        st.warning(f"Skipping {uploaded_file.name}: the file is not valid UTF-8 text.")
        continue
    # "source" mirrors the metadata key TextLoader sets for files read from
    # disk; "name" is kept for backward compatibility.
    all_docs.append(
        Document(
            page_content=content,
            metadata={"name": uploaded_file.name, "source": uploaded_file.name},
        )
    )

# Nothing to index: tell the user and end this script run.
if not all_docs:
    st.warning("Please select files from the source directory or upload TXT files.")
    st.stop()
|
|
| |
@st.cache_resource
def _load_embeddings():
    """Create the BGE embedding model once and reuse it across reruns.

    Streamlit re-executes this script on every widget interaction; without
    caching, the HuggingFace model would be rebuilt each time.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-base-en-v1.5",
        model_kwargs={'device': 'cpu'},
        encode_kwargs={'normalize_embeddings': True},
    )


# Split the loaded documents into small overlapping chunks for retrieval.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=100, chunk_overlap=10, length_function=len)
split_docs = text_splitter.split_documents(all_docs)

# Empty or whitespace-only files produce no chunks; stop here rather than
# asking the vector store to index an empty batch.
if not split_docs:
    st.warning("The selected files contain no text to index.")
    st.stop()

embeddings = _load_embeddings()
# The index is rebuilt on each run from the currently selected documents.
db = Chroma.from_documents(split_docs, embeddings)
# Retrieve the four most similar chunks per query.
retriever = db.as_retriever(search_kwargs={"k": 4})
|
|
| |
# One worker agent per capability, each limited to a single tool.
research_agent = create_agent(llm, [tavily_tool], "You are a web researcher.")
code_agent = create_agent(
    llm,
    [python_repl_tool],
    "You may generate safe python code to analyze data and generate charts using matplotlib.",
)
RAG_agent = create_agent(
    llm,
    [RAG],
    "Use this tool when questions are related to Japan or Sports category.",
)

# Bind each agent and its display name into agent_node to get graph nodes.
research_node = functools.partial(agent_node, agent=research_agent, name="Researcher")
code_node = functools.partial(agent_node, agent=code_agent, name="Coder")
rag_node = functools.partial(agent_node, agent=RAG_agent, name="RAG")

# Worker names double as graph node names and as routing choices.
members = ["RAG", "Researcher", "Coder"]
system_prompt = (
    "You are a supervisor managing these workers: {members}. Respond with the next worker or FINISH. "
    "Use RAG tool for Japan or Sports questions."
)
options = ["FINISH", *members]

# Function schema that forces the model to answer with exactly one option.
function_def = {
    "name": "route",
    "description": "Select the next role.",
    "parameters": {
        "title": "routeSchema",
        "type": "object",
        "properties": {"next": {"anyOf": [{"enum": options}]}},
        "required": ["next"],
    },
}

prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="messages"),
        ("system", "Given the conversation above, who should act next? Select one of: {options}"),
    ]
).partial(options=str(options), members=", ".join(members))

# Supervisor: prompt -> model forced to call "route" -> parsed JSON arguments.
supervisor_chain = (
    prompt
    | llm.bind_functions(functions=[function_def], function_call="route")
    | JsonOutputFunctionsParser()
)
|
|
| |
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the workflow graph."""

    # Conversation so far; operator.add is attached as the annotation so that
    # message lists returned by nodes are appended rather than replaced.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Routing decision that the supervisor's conditional edge reads.
    next: str
|
|
workflow = StateGraph(AgentState)

# Register the three workers and the supervisor, in this order.
_graph_nodes = {
    "Researcher": research_node,
    "Coder": code_node,
    "RAG": rag_node,
    "supervisor": supervisor_chain,
}
for _node_name, _node in _graph_nodes.items():
    workflow.add_node(_node_name, _node)

# Every worker hands control back to the supervisor when it finishes.
for member in members:
    workflow.add_edge(member, "supervisor")

# The supervisor's "next" value selects a worker by name, or ends the run.
conditional_map = {**{worker: worker for worker in members}, "FINISH": END}
workflow.add_conditional_edges("supervisor", lambda x: x["next"], conditional_map)
workflow.set_entry_point("supervisor")
graph = workflow.compile()
|
|
| |
# Create the per-session output log on the first run only.
if "outputs" not in st.session_state:
    st.session_state["outputs"] = []

# Free-text task entry; the first example question is shown as the placeholder.
user_input = st.text_area(
    "Enter your task or question:",
    placeholder=example_questions[0],
)
|
|
def run_workflow(task):
    """Stream the graph for *task* and record each step in the session log.

    The previous log is cleared first. Every streamed state that does not
    contain "__end__" is stored as its string form followed by a "----"
    separator line.

    Args:
        task: The user's question; sent to the graph as a single HumanMessage.
    """
    log = st.session_state.outputs
    log.clear()
    log.append(f"User Input: {task}")

    initial_state = {"messages": [HumanMessage(content=task)]}
    for state in graph.stream(initial_state):
        if "__end__" in state:
            continue
        log.extend([str(state), "----"])
|
|
# Run the graph only on an explicit click, and only with a non-empty task.
if st.button("Run Workflow"):
    if not user_input:
        st.warning("Please enter a task or question.")
    else:
        run_workflow(user_input)

# Static list of sample prompts.
st.subheader("Example Questions:")
for example in example_questions:
    st.text("- " + example)

# Replay the stored log so the last run stays visible across reruns.
st.subheader("Workflow Output:")
for output in st.session_state["outputs"]:
    st.text(output)
|
|