File size: 2,785 Bytes
0a1e1d3
 
 
 
 
 
 
 
 
 
cee86e1
0a1e1d3
 
 
 
 
 
 
287b070
0a1e1d3
 
cc75b85
cee86e1
0a1e1d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os

import requests
from dotenv import load_dotenv
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langgraph.constants import END
from langgraph.graph import StateGraph
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import ToolNode

from prompts import AGENT_SYSTEM_PROMPT, build_question_prompt
from state import AgentState
from tools import (
    wikipedia_tool,
    arxiv_tool,
    get_current_year,
    ddg_search_tool,
    get_youtube_transcript,
    fetch_url_content,
    get_gaia_file,
)

# Load variables from a local .env file (if any) into the process
# environment before any of the settings below are read.
load_dotenv()

# Credentials / identity pulled from the environment; each is None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
HF_USERNAME = os.getenv("HF_USERNAME")
# NOTE(review): GROQ_API_KEY is not passed to ChatGroq below; presumably
# ChatGroq reads it from the environment on its own — confirm.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")

# Maximum number of graph steps per question (passed to app.invoke).
RECURSION_LIMIT = 50
# Public link to this agent's code; overridable via the AGENT_CODE_URL env var.
AGENT_CODE_URL = os.getenv(
    "AGENT_CODE_URL",
    "https://huggingface.co/spaces/manasajanj/Final_Assignment_Template/tree/main",
)

# Chat model plus the tools it is allowed to call. The same list is handed
# to the graph's ToolNode so the model and executor agree on available tools.
llm = ChatGroq(model="openai/gpt-oss-120b")
tools_list = [
    wikipedia_tool,
    ddg_search_tool,
    arxiv_tool,
    get_current_year,
    get_youtube_transcript,
    fetch_url_content,
    get_gaia_file,
]
llm_with_tools = llm.bind_tools(tools_list)


def agent_node(state: AgentState) -> dict:
    """Ask the tool-bound LLM for the next step of the conversation.

    The system prompt is prepended to the history sent to the model whenever
    the stored messages contain no ``SystemMessage``. It is not returned, so
    it is never written back into the state.

    Args:
        state: Current graph state; only ``state["messages"]`` is read.

    Returns:
        A partial state update ``{"messages": [reply]}`` holding the model's
        reply. Presumably ``AgentState.messages`` uses an append-style reducer
        so this is added to the history rather than replacing it — confirm in
        ``state.py``.
    """
    history = state["messages"]
    system_prompt = SystemMessage(content=AGENT_SYSTEM_PROMPT)
    already_prompted = any(isinstance(msg, SystemMessage) for msg in history)
    prompt = history if already_prompted else [system_prompt] + history
    return {"messages": [llm_with_tools.invoke(prompt)]}
def _message_text(content) -> str:
    """Return the plain text of a LangChain message ``content`` value.

    ``content`` may be a plain string or a list of content blocks. A block is
    either a bare string or a dict such as ``{"type": "text", "text": ...}``.
    String blocks and ``"text"``-typed dict blocks are concatenated, and any
    other block types (e.g. reasoning, images) are skipped. ``None`` becomes
    ``""``, and any other value is converted with ``str``.
    """
    if isinstance(content, str):
        return content
    if content is None:
        return ""
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                text = block.get("text")
                if isinstance(text, str):
                    parts.append(text)
        return "".join(parts)
    return str(content)


def finish_node(state: AgentState) -> dict:
    """Copy the final message's text into the ``answer`` field.

    Args:
        state: Current graph state; the last entry of ``state["messages"]``
            is treated as the agent's final reply.

    Returns:
        ``{"answer": text}``, where *text* is the whitespace-stripped text of
        that message. List-style content blocks are flattened to text first,
        so a non-string ``content`` no longer crashes on ``.strip()``.
    """
    last_message = state["messages"][-1]
    return {"answer": _message_text(last_message.content).strip()}
def should_continue(state: AgentState) -> str:
    """Choose the edge to follow after the agent step.

    Args:
        state: Current graph state; only the last entry of
            ``state["messages"]`` is inspected.

    Returns:
        ``"tools"`` when that message has a non-empty ``tool_calls``
        attribute, otherwise ``"end"`` (attribute missing, empty, or None).
    """
    pending_calls = getattr(state["messages"][-1], "tool_calls", None)
    return "tools" if pending_calls else "end"
def build_graph() -> CompiledStateGraph:
    """Build and compile the agent / tools / finish workflow.

    Flow:
        ``agent`` runs first. If its reply requests tool calls (as decided by
        ``should_continue``), execution goes to ``tools`` and then loops back
        to ``agent``. Otherwise it goes to ``finish``, which then ends the run.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``, ready for
        ``invoke``. The previous ``-> StateGraph`` annotation was incorrect:
        ``compile()`` returns a ``CompiledStateGraph``, not the builder.
    """
    graph = StateGraph(AgentState)
    graph.add_node("agent", agent_node)
    # Executes the tool calls requested by the agent, using the same tool
    # list that was bound to the LLM.
    graph.add_node("tools", ToolNode(tools_list))
    graph.add_node("finish", finish_node)

    graph.set_entry_point("agent")
    graph.add_conditional_edges(
        "agent",
        should_continue,
        {"tools": "tools", "end": "finish"},
    )

    # Tool results are fed back to the agent for another reasoning step.
    graph.add_edge("tools", "agent")

    graph.add_edge("finish", END)
    return graph.compile()


# Compiled once at import time; solve_question invokes this shared instance.
app = build_graph()

def solve_question(task: dict) -> str:
    """Answer a single task by running it through the compiled agent graph.

    Args:
        task: Mapping with at least ``"question"`` and ``"task_id"`` keys.
            ``"file_name"`` is optional; a missing or falsy value means no
            attached file.

    Returns:
        The ``answer`` field of the final graph state, which ``finish_node``
        fills in.

    Raises:
        KeyError: if ``"question"`` or ``"task_id"`` is missing from *task*.
        Errors from ``app.invoke`` propagate. Presumably this includes
        LangGraph's recursion error when ``RECURSION_LIMIT`` is exceeded —
        confirm.
    """
    question = task["question"]
    attachment = task.get("file_name") or ""
    prompt_text = build_question_prompt(question=question, file_name=attachment)

    first_message = HumanMessage(content=prompt_text)
    initial_state: AgentState = AgentState(
        messages=[first_message],
        task_id=task["task_id"],
        question=question,
        # The prompt builder gets "" for "no file"; the state stores None.
        file_name=attachment if attachment else None,
        answer=None,
    )

    final_state = app.invoke(
        initial_state,
        config={"recursion_limit": RECURSION_LIMIT},
    )
    return final_state["answer"]