# Author: manasajanj — commit cee86e1 ("Update agent.py", verified)
import os

import requests
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_groq import ChatGroq
from langgraph.constants import END
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from prompts import AGENT_SYSTEM_PROMPT, build_question_prompt
from state import AgentState
from tools import (
    arxiv_tool,
    ddg_search_tool,
    fetch_url_content,
    get_current_year,
    get_gaia_file,
    get_youtube_transcript,
    wikipedia_tool,
)
# Load variables from a local .env file (if any) before reading them below.
load_dotenv()

# Configuration pulled from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_USERNAME = os.getenv("HF_USERNAME")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
AGENT_CODE_URL = os.getenv(
    "AGENT_CODE_URL",
    "https://huggingface.co/spaces/manasajanj/Final_Assignment_Template/tree/main",
)

# Maximum number of graph steps per question; passed to app.invoke as
# LangGraph's "recursion_limit".
RECURSION_LIMIT = 50

# NOTE(review): the key is not passed explicitly; ChatGroq presumably reads
# GROQ_API_KEY from the environment itself — confirm if this ever changes.
llm = ChatGroq(model="openai/gpt-oss-120b")

# Tools the model may call; the same list backs the graph's ToolNode.
tools_list = [
    wikipedia_tool,
    ddg_search_tool,
    arxiv_tool,
    get_current_year,
    get_youtube_transcript,
    fetch_url_content,
    get_gaia_file,
]
llm_with_tools = llm.bind_tools(tools_list)
def agent_node(state: AgentState) -> dict:
    """Run one model step over the conversation history.

    If the history contains no SystemMessage, the agent system prompt is
    prepended for this call only. The prompt is not returned; only the
    model's reply is.

    Args:
        state: Graph state; ``state["messages"]`` is the conversation so far.

    Returns:
        A partial state update ``{"messages": [reply]}``.
    """
    history = state["messages"]
    if not any(isinstance(msg, SystemMessage) for msg in history):
        history = [SystemMessage(content=AGENT_SYSTEM_PROMPT)] + history
    reply = llm_with_tools.invoke(history)
    return {"messages": [reply]}
def finish_node(state: AgentState) -> dict:
    """Turn the last message's text into the final answer.

    LangChain message ``content`` is either a plain string or a list of
    content blocks (strings, or dicts such as ``{"type": "text", "text": ...}``).
    Text parts are concatenated. Non-text blocks are ignored.

    Args:
        state: Graph state; the last entry of ``state["messages"]`` is read.

    Returns:
        ``{"answer": <text with surrounding whitespace stripped>}``.
    """
    content = state["messages"][-1].content
    if isinstance(content, list):
        # A list has no .strip(); flatten the text blocks into one string first.
        texts = []
        for part in content:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                texts.append(part.get("text") or "")
        content = "".join(texts)
    return {"answer": content.strip()}
def should_continue(state: AgentState) -> str:
    """Choose the next route after an agent step.

    Returns:
        ``"tools"`` when the latest message carries non-empty ``tool_calls``.
        Otherwise ``"end"``, including when the message has no such attribute.
    """
    pending_calls = getattr(state["messages"][-1], "tool_calls", None)
    return "tools" if pending_calls else "end"
def build_graph() -> CompiledStateGraph:
    """Assemble and compile the agent workflow.

    Topology::

        agent --(tool calls pending)--> tools --> agent
        agent --(no tool calls)-------> finish --> END

    Returns:
        The compiled graph produced by ``StateGraph.compile()``, ready for
        ``invoke``.
    """
    graph = StateGraph(AgentState)
    graph.add_node("agent", agent_node)
    graph.add_node("tools", ToolNode(tools_list))
    graph.add_node("finish", finish_node)

    graph.set_entry_point("agent")
    # should_continue returns the labels "tools" / "end"; map them to nodes.
    graph.add_conditional_edges(
        "agent", should_continue, {"tools": "tools", "end": "finish"}
    )
    graph.add_edge("tools", "agent")  # feed tool results back to the model
    graph.add_edge("finish", END)
    return graph.compile()


# Compiled once at import time; solve_question invokes this shared instance.
app = build_graph()
def solve_question(task: dict) -> str:
    """Run the agent graph on a single task and return its answer.

    Args:
        task: Mapping with required keys ``"question"`` and ``"task_id"``, and
            an optional ``"file_name"`` (a missing or empty value means no
            attachment).

    Returns:
        The ``"answer"`` field of the final graph state.
    """
    question_text = task["question"]
    attachment = task.get("file_name") or ""

    prompt = build_question_prompt(question=question_text, file_name=attachment)

    start_state: AgentState = AgentState(
        messages=[HumanMessage(content=prompt)],
        task_id=task["task_id"],
        question=question_text,
        # State stores None rather than "" when there is no attachment.
        file_name=attachment if attachment else None,
        answer=None,
    )
    final_state = app.invoke(
        start_state, config={"recursion_limit": RECURSION_LIMIT}
    )
    return final_state["answer"]