Spaces:
Sleeping
Sleeping
| import os | |
| import operator | |
| import functools | |
| from typing import Annotated, Sequence, TypedDict, Union, Optional | |
| from dotenv import load_dotenv | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.messages import BaseMessage, HumanMessage, AIMessage | |
| from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain_core.runnables import Runnable | |
| from langchain.output_parsers.openai_tools import JsonOutputKeyToolsParser | |
| from langgraph.graph import StateGraph, END | |
| from application.agents.scraper_agent import scraper_agent | |
| from application.agents.extractor_agent import extractor_agent | |
| from application.utils.logger import get_logger | |
load_dotenv()
logger = get_logger()

# Validate the key at import time so a misconfigured deployment fails fast,
# before the LLM client and the graph below are constructed.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    logger.error("OPENAI_API_KEY is missing. Please set it in your environment variables.")
    # OSError is the same class as the legacy EnvironmentError alias.
    raise OSError("OPENAI_API_KEY not found in environment variables.")
# Worker node names; these must match the node names registered on the graph.
MEMBERS = ["Scraper", "Extractor"]
# Every value the supervisor may route to: the workers plus the terminal choice.
OPTIONS = ["FINISH", *MEMBERS]

# System prompt for the supervisor. "{members}" is a template variable that is
# filled in when the prompt template is partially applied.
SUPERVISOR_SYSTEM_PROMPT = "\n".join(
    [
        "You are a supervisor tasked with managing a conversation between the following workers: {members}. "
        "Given the user's request and the previous messages, determine what to do next:",
        "- If the user asks to search, find, or scrape data from the web, choose 'Scraper'.",
        "- If the user asks to extract ESG emissions data from a file or PDF, choose 'Extractor'.",
        "- If the task is complete, choose 'FINISH'.",
        "- If the message is general conversation (like greetings, questions, thanks, chatting), directly respond with a message.",
        "Each worker will perform its task and report back.",
        "When you respond directly, make sure your message is friendly and helpful.",
    ]
)
# JSON-schema for the "next" argument: restricted to the routing OPTIONS.
_NEXT_WORKER_SCHEMA = {
    "title": "Next Worker",
    "anyOf": [{"enum": OPTIONS}],
    "description": "Choose next worker if needed.",
}

# JSON-schema for the "response" argument: a free-text direct reply.
_SUPERVISOR_RESPONSE_SCHEMA = {
    "title": "Supervisor Response",
    "type": "string",
    "description": "Respond directly if no worker action is needed.",
}

# Tool definition bound to the supervisor LLM. Neither argument is required;
# supervisor_node rejects calls that supply neither.
FUNCTION_DEF = {
    "name": "route_or_respond",
    "description": "Select the next role OR respond directly.",
    "parameters": {
        "title": "RouteOrRespondSchema",
        "type": "object",
        "properties": {
            "next": _NEXT_WORKER_SCHEMA,
            "response": _SUPERVISOR_RESPONSE_SCHEMA,
        },
        "required": [],
    },
}
class AgentState(TypedDict):
    """Graph state shared by the supervisor and the worker nodes."""

    # Conversation history. The operator.add reducer concatenates each node's
    # returned list onto the existing history instead of replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Supervisor's routing choice: a worker name or "FINISH".
    next: Optional[str]
    # Supervisor's direct reply; when truthy the router sends the run to the
    # "supervisor_response" node instead of a worker.
    response: Optional[str]
# Chat model used by the supervisor chain; temperature 0 keeps routing
# decisions as repeatable as the API allows.
llm = ChatOpenAI(temperature=0, model="gpt-4o-mini")
def agent_node(state: AgentState, agent: Runnable, name: str) -> dict:
    """Run one worker agent on the current state and report its result.

    Args:
        state: Current graph state, passed to the agent unchanged.
        agent: The worker runnable. It is assumed to return a mapping with an
            ``"output"`` key (as an AgentExecutor does) — confirm for new agents.
        name: Worker name, used for logging and as the message author.

    Returns:
        A partial state update whose single message is appended to the history
        by the ``operator.add`` reducer on ``messages``.

    Raises:
        Exception: Any failure, including a missing ``"output"`` key, is logged
            with its traceback and then re-raised unchanged.
    """
    logger.info(f"Agent {name} invoked.")
    try:
        agent_result = agent.invoke(state)
        logger.info(f"Agent {name} completed successfully.")
        # NOTE(review): worker output is posted as a HumanMessage, presumably so
        # the supervisor LLM treats it as fresh input — confirm intent.
        worker_message = HumanMessage(content=agent_result["output"], name=name)
    except Exception as err:
        logger.exception(f"Agent {name} failed with error: {str(err)}")
        raise
    return {"messages": [worker_message]}
# Supervisor prompt: system instructions, then the running conversation, then
# a closing reminder of the allowed choices.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", SUPERVISOR_SYSTEM_PROMPT),
        MessagesPlaceholder(variable_name="messages"),
        (
            "system",
            "Based on the conversation, either select next worker (one of: {options}) or respond directly with a message.",
        ),
    ]
)
# Bind the static template variables once, so callers only supply "messages".
prompt = prompt.partial(members=", ".join(MEMBERS), options=str(OPTIONS))
# Bind the route_or_respond tool and force the model to call it via tool_choice.
_supervisor_llm = llm.bind_tools(tools=[FUNCTION_DEF], tool_choice="route_or_respond")
# Extract the JSON arguments of route_or_respond tool calls from the model reply.
_supervisor_parser = JsonOutputKeyToolsParser(key_name="route_or_respond")

# prompt -> tool-calling LLM -> parsed tool arguments.
supervisor_chain = prompt | _supervisor_llm | _supervisor_parser
def supervisor_node(state: AgentState) -> dict:
    """Ask the supervisor LLM what to do next and record its decision.

    Returns a partial state update containing only ``next`` and ``response``.
    ``messages`` is deliberately NOT returned. That channel is reduced with
    ``operator.add``, so echoing ``state["messages"]`` back would append the
    entire history to itself on every supervisor turn.

    Both keys are always returned, even when ``None``, so that a value left
    over from an earlier turn is overwritten.

    Raises:
        ValueError: If the model made no ``route_or_respond`` call, gave neither
            a route nor a reply, or routed to a value outside ``OPTIONS``.
    """
    logger.info("Supervisor invoked.")
    output = supervisor_chain.invoke(state)
    logger.info(f"Supervisor output: {output}")

    # The key-tools parser yields a list of argument dicts, one per matching
    # tool call. Only the first is used; an empty list means no call was made.
    if isinstance(output, list):
        if not output:
            raise ValueError(f"Supervisor produced invalid output: {output}")
        output = output[0]

    next_step = output.get("next")
    response = output.get("response")
    if not next_step and not response:
        raise ValueError(f"Supervisor produced invalid output: {output}")
    # The router prefers a direct reply, so the route only matters (and is
    # only validated) when there is no reply.
    if not response and next_step not in OPTIONS:
        raise ValueError(f"Supervisor produced invalid output: {output}")

    return {"next": next_step, "response": response}
workflow = StateGraph(AgentState)

# Worker nodes: agent_node wraps each agent and posts its output to the shared
# history under the worker's name.
workflow.add_node("Scraper", functools.partial(agent_node, agent=scraper_agent, name="Scraper"))
workflow.add_node("Extractor", functools.partial(agent_node, agent=extractor_agent, name="Extractor"))
workflow.add_node("supervisor", supervisor_node)


def _emit_supervisor_response(state: AgentState) -> dict:
    """Append the supervisor's direct reply to the history as an AIMessage."""
    return {"messages": [AIMessage(content=state["response"], name="Supervisor")]}


workflow.add_node("supervisor_response", _emit_supervisor_response)
# A direct reply ends the run. Without an outgoing edge this node is a dead
# end; make termination explicit instead of relying on version-dependent
# implicit-end behaviour in LangGraph.
workflow.add_edge("supervisor_response", END)

# Every worker reports back to the supervisor, which decides the next step.
for member in MEMBERS:
    workflow.add_edge(member, "supervisor")
def router(state: AgentState):
    """Choose the supervisor's outgoing edge.

    A truthy direct reply takes priority and goes to "supervisor_response".
    Otherwise the supervisor's ``next`` value is returned as the edge key.
    """
    has_reply = bool(state.get("response"))
    return "supervisor_response" if has_reply else state.get("next")
# Maps router return values to destination nodes. Workers map to themselves,
# "FINISH" terminates the run, and a direct reply goes to its emitter node.
conditional_map = {
    **{worker: worker for worker in MEMBERS},
    "FINISH": END,
    "supervisor_response": "supervisor_response",
}
workflow.add_conditional_edges("supervisor", router, conditional_map)

# Every run starts with the supervisor deciding what to do.
workflow.set_entry_point("supervisor")
graph = workflow.compile()
| # # === Example Run === | |
| # if __name__ == "__main__": | |
| # logger.info("Starting the graph execution...") | |
| # initial_message = HumanMessage(content="Can you get zalando pdf link") | |
| # input_state = {"messages": [initial_message]} | |
| # for step in graph.stream(input_state): | |
| # if "__end__" not in step: | |
| # logger.info(f"Graph Step Output: {step}") | |
| # print(step) | |
| # print("----") | |
| # logger.info("Graph execution completed.") |