| import os |
|
|
| import requests |
| from dotenv import load_dotenv |
| from langchain_core.messages import SystemMessage, HumanMessage |
| from langgraph.constants import END |
| from langgraph.prebuilt import ToolNode |
|
|
| from prompts import AGENT_SYSTEM_PROMPT, build_question_prompt |
| from state import AgentState |
| from tools import wikipedia_tool, arxiv_tool, get_current_year, ddg_search_tool,get_youtube_transcript, fetch_url_content,get_gaia_file |
| from langchain_groq import ChatGroq |
|
|
| from langgraph.graph import StateGraph |
|
|
# --- Environment & model configuration --------------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
HF_USERNAME = os.getenv("HF_USERNAME")
# NOTE(review): never referenced below — presumably ChatGroq reads the
# GROQ_API_KEY env var itself; confirm before removing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
# Hard cap on graph steps per question, so a tool-calling loop cannot spin forever.
RECURSION_LIMIT = 50
AGENT_CODE_URL = os.getenv(
    "AGENT_CODE_URL",
    "https://huggingface.co/spaces/manasajanj/Final_Assignment_Template/tree/main",
)

# LLM with the full toolbox bound so it can emit structured tool calls.
llm = ChatGroq(model="openai/gpt-oss-120b")
tools_list = [
    wikipedia_tool,
    ddg_search_tool,
    arxiv_tool,
    get_current_year,
    get_youtube_transcript,
    fetch_url_content,
    get_gaia_file,
]
llm_with_tools = llm.bind_tools(tools_list)
|
|
|
|
def agent_node(state:AgentState) -> dict:
    """Run the tool-enabled LLM on the conversation so far.

    Prepends the agent system prompt if no SystemMessage is present yet
    (i.e. on the first pass through the graph), then invokes the model
    and returns its reply to be appended to the message history.
    """
    history = state["messages"]
    # Inject the system prompt only once, at the start of the conversation.
    if not any(isinstance(msg, SystemMessage) for msg in history):
        history = [SystemMessage(content=AGENT_SYSTEM_PROMPT), *history]
    reply = llm_with_tools.invoke(history)
    return {"messages": [reply]}
def finish_node(state:AgentState) -> dict:
    """Capture the agent's final answer from the conversation.

    Takes the most recent message and stores its whitespace-stripped
    text under the ``answer`` key of the state.
    """
    final = state["messages"][-1]
    return {"answer": final.content.strip()}
def should_continue(state: AgentState) -> str:
    """Route after the agent node.

    Returns ``"tools"`` when the latest message carries pending tool
    calls (so the ToolNode runs next), otherwise ``"end"`` to hand off
    to the finish node.
    """
    tail = state["messages"][-1]
    return "tools" if getattr(tail, "tool_calls", None) else "end"
def build_graph() -> StateGraph:
    """Assemble and compile the agent workflow.

    Wiring: agent -> (tools -> agent loop, or finish -> END).
    NOTE: the returned object is the *compiled* graph produced by
    ``compile()``, not the ``StateGraph`` builder named in the annotation.
    """
    workflow = StateGraph(AgentState)

    workflow.add_node("agent", agent_node)
    workflow.add_node("tools", ToolNode(tools_list))
    workflow.add_node("finish", finish_node)

    workflow.set_entry_point("agent")
    # After each agent turn, either execute the requested tools or wrap up.
    workflow.add_conditional_edges(
        "agent", should_continue, {"tools": "tools", "end": "finish"}
    )
    workflow.add_edge("tools", "agent")
    workflow.add_edge("finish", END)

    return workflow.compile()
| app = build_graph() |
|
|
def solve_question(task: dict) -> str:
    """Answer a single task with the compiled agent graph.

    Args:
        task: Mapping with at least ``question`` and ``task_id``;
            ``file_name`` is optional and may be empty.

    Returns:
        The agent's final answer string, as produced by the finish node.
    """
    question = task["question"]
    attachment = task.get("file_name") or ""
    prompt = build_question_prompt(question=question, file_name=attachment)

    initial_state: AgentState = AgentState(
        messages=[HumanMessage(content=prompt)],
        task_id=task["task_id"],
        question=question,
        file_name=attachment or None,
        answer=None,
    )

    # Bound the number of graph steps so tool loops cannot run forever.
    outcome = app.invoke(initial_state, config={"recursion_limit": RECURSION_LIMIT})
    return outcome["answer"]
|
|
|
|
|
|
|
|
|
|
|
|