# NOTE: "Spaces: Sleeping / Sleeping" was a Hugging Face Spaces status banner
# captured by the page scrape — it is not part of this Python module.
| import os | |
| from dotenv import load_dotenv | |
| import json | |
| from langchain_core.messages import ToolMessage | |
| from typing import TypedDict, Annotated | |
| from langgraph.graph.message import add_messages | |
| from typing import Annotated, List | |
| from langgraph.graph import StateGraph, END | |
| from langchain_core.messages import ToolMessage, HumanMessage | |
| from langchain_openai import ChatOpenAI | |
| # Local Imports | |
| from application.tools.web_search_tools import get_top_companies_from_web, get_sustainability_report_pdf | |
| from application.tools.pdf_downloader_tool import download_pdf | |
| from application.tools.emission_data_extractor import extract_emission_data_as_json | |
| from application.services.langgraph_service import create_agent | |
| from application.utils.logger import get_logger | |
# --- Environment and logger setup ---
load_dotenv()
logger = get_logger()

# LangSmith tracing configuration.
# BUG FIX: `os.environ[...] = os.getenv(...)` raises TypeError when the
# variable is absent (os.getenv returns None). Only set it when present.
LANGSMITH_API_KEY = os.getenv('LANGSMITH_API_KEY')
if LANGSMITH_API_KEY is not None:
    os.environ['LANGSMITH_API_KEY'] = LANGSMITH_API_KEY
os.environ['LANGCHAIN_TRACING_V2'] = 'true'
os.environ["LANGCHAIN_PROJECT"] = "Sustainability_AI"

# OpenAI API key set up.
# BUG FIX: same TypeError as above when OPENAI_API_KEY is unset; the
# re-assignment is a no-op when the key exists, so guard it instead.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
if OPENAI_API_KEY is not None:
    os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
class AgentState(TypedDict):
    """Shared graph state: the running conversation history."""
    # add_messages is a reducer: node returns are appended to the list
    # rather than replacing it.
    messages: Annotated[List, add_messages]
# NOTE(review): this StateGraph instance appears unused — the graph that is
# actually compiled is built from `graph_builder` below; confirm no external
# module imports `graph` before removing it.
graph = StateGraph(AgentState)
model = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Tools exposed to the LLM for tool-calling.
tools = [get_top_companies_from_web, get_sustainability_report_pdf, download_pdf, extract_emission_data_as_json]
model_with_tools = model.bind_tools(tools)
def invoke_model(state: AgentState) -> dict:
    """Run the tool-bound LLM over the current conversation history.

    Returns a partial state update holding the model's reply; the
    add_messages reducer appends it to the existing message list.
    """
    logger.info("--- Invoking Model ---")
    ai_reply = model_with_tools.invoke(state['messages'])
    return {"messages": [ai_reply]}
def _serialize_tool_result(result) -> str:
    """Render a tool's raw return value as string content for a ToolMessage."""
    # JSON-encode plain containers so the model sees structured data.
    if isinstance(result, (list, dict)):
        return json.dumps(result)
    # Some tools return an object exposing a `companies` list attribute.
    if hasattr(result, 'companies') and isinstance(result.companies, list):
        return f"Companies found: {', '.join(result.companies)}"
    if result is None:
        return "Tool executed successfully, but returned no specific data (None)."
    return str(result)


def invoke_tools(state: AgentState) -> dict:
    """Invokes the necessary tools based on the last AI message.

    Executes every tool call attached to the last message and returns a
    partial state update with one ToolMessage per call (empty dict when
    the message carries no tool calls). Individual tool failures are
    reported back to the model as error ToolMessages instead of aborting
    the graph run.
    """
    logger.info("--- Invoking Tools ---")
    last_message = state['messages'][-1]
    # getattr guard covers both "no attribute" and "empty tool_calls".
    if not getattr(last_message, 'tool_calls', None):
        logger.info("No tool calls found in the last message.")
        return {}
    tool_invocation_messages = []
    tool_map = {tool.name: tool for tool in tools}
    for tool_call in last_message.tool_calls:
        tool_name = tool_call['name']
        tool_args = tool_call['args']
        tool_call_id = tool_call['id']
        logger.info(f"Executing tool: {tool_name} with args: {tool_args}")
        # Guard clause: the model may hallucinate a tool name.
        if tool_name not in tool_map:
            logger.warning(f"Tool '{tool_name}' not found.")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error: Tool '{tool_name}' not found.", tool_call_id=tool_call_id)
            )
            continue
        try:
            result = tool_map[tool_name].invoke(tool_args)
            result_content = _serialize_tool_result(result)
            logger.info(f"Tool {tool_name} result: {result_content}")
            tool_invocation_messages.append(
                ToolMessage(content=result_content, tool_call_id=tool_call_id)
            )
        except Exception as e:
            # Surface the failure to the model so it can recover or retry.
            logger.error(f"Error executing tool {tool_name}: {e}")
            tool_invocation_messages.append(
                ToolMessage(content=f"Error executing tool {tool_name}: {str(e)}", tool_call_id=tool_call_id)
            )
    return {"messages": tool_invocation_messages}
# Build the agent graph: an LLM node plus a tool-execution node.
graph_builder = StateGraph(AgentState)
graph_builder.add_node("scraper_agent", invoke_model)
graph_builder.add_node("tools", invoke_tools)
# Every run starts with a model call.
graph_builder.set_entry_point("scraper_agent")
def router(state: AgentState) -> str:
    """Decide the next node after a model step.

    Routes to the tool node when the latest AI message requested tool
    calls; otherwise terminates the graph.
    """
    latest = state['messages'][-1]
    # getattr covers both a missing attribute and an empty tool_calls list.
    if getattr(latest, 'tool_calls', None):
        logger.info("--- Routing to Tools ---")
        return "tools"
    logger.info("--- Routing to End ---")
    return END
# Conditional edge: after each model step, either execute tools or finish.
graph_builder.add_conditional_edges(
    "scraper_agent",
    router,
    {
        "tools": "tools",
        END: END,
    }
)
# Tool results always flow back to the model for the next reasoning step.
graph_builder.add_edge("tools", "scraper_agent")
# Compile the graph
app = graph_builder.compile()
| # # --- Running the Graph --- | |
| # if __name__ == "__main__": | |
| # logger.info("Starting graph execution...") | |
| # # Use HumanMessage for the initial input | |
| # initial_input = {"messages": [HumanMessage(content="Please download this pdf https://www.infosys.com/sustainability/documents/infosys-esg-report-2023-24.pdf")]} | |
| # # Stream events to see the flow (optional, but helpful for debugging) | |
| # # Add recursion limit to prevent infinite loops | |
| # try: | |
| # final_state = None | |
| # for event in app.stream(initial_input, {"recursion_limit": 15}): | |
| # # event is a dictionary where keys are node names and values are outputs | |
| # logger.info(f"Event: {event}") | |
| # # Keep track of the latest state if needed, especially the messages | |
| # if "scraper_agent" in event: | |
| # final_state = event["scraper_agent"] | |
| # elif "tools" in event: | |
| # final_state = event["tools"] # Though tool output doesn't directly give full state | |
| # logger.info("---") | |
| # logger.info("\n--- Final State Messages ---") | |
| # # To get the absolute final state after streaming, invoke might be simpler, | |
| # # or you need to properly aggregate the state from the stream events. | |
| # # A simpler way to get final output: | |
| # final_output = app.invoke(initial_input, {"recursion_limit": 15}) | |
| # logger.info(json.dumps(final_output['messages'][-1].dict(), indent=2)) # Print the last message | |
| # except Exception as e: | |
| # logger.error(f"\n--- An error occurred during graph execution ---") | |
| # import traceback | |
| # traceback.print_exc() | |
# System prompt for the scraper agent.
# BUG FIX: two mojibake "β" characters (mis-encoded dashes) in the prompt
# text are replaced with em-dashes.
# NOTE(review): the tool names below (search_tool, pdf_finder_tool,
# pdf_downloader_tool) are local aliases of the real tool functions; verify
# the names the LLM actually sees match what the prompt references.
SCRAPER_SYSTEM_PROMPT = """
You are an intelligent assistant specialized in company research and sustainability report retrieval.
You have access to the following tools:
- **search_tool**: Use this tool when the user asks for a list of top companies related to an industry or category (e.g., "top 5 textile companies"). Always preserve any number mentioned (e.g., 'top 5', 'top 10') in the query.
- **pdf_finder_tool**: Use this tool when the user requests a sustainability report or any other specific PDF document about a company. Search specifically for the latest sustainability report if not otherwise specified.
- **pdf_downloader_tool**: Use this tool when the user provides a direct PDF link or asks you to download a PDF document from a URL.
Instructions:
- Carefully read the user's request and select the correct tool based on their intent.
- Always preserve important details like quantity (e.g., "top 5"), industry, or company name.
- If the user mentions multiple companies and asks for reports, find reports for **each** company individually.
- Do not add assumptions, opinions, or unrelated information.
- Always generate clean, direct, and minimal input for the tool — close to the user's original query.
- Prioritize the most recent information when searching for reports unless otherwise instructed.
Goal:
- Select the appropriate tool.
- Build a precise query that perfectly reflects the user's request.
- Return only what the user asks — no extra text or interpretation.
"""
# Aliases matching the tool names referenced in SCRAPER_SYSTEM_PROMPT.
search_tool = get_top_companies_from_web
pdf_finder_tool = get_sustainability_report_pdf
pdf_downloader_tool = download_pdf
llm = ChatOpenAI(model= 'gpt-4o-mini', temperature=0)
# Scraper agent gets search/find/download only — no emission extractor here.
scraper_agent = create_agent(llm, [search_tool, pdf_finder_tool, pdf_downloader_tool], SCRAPER_SYSTEM_PROMPT)