File size: 5,006 Bytes
9fb199c
d3a1a1a
d885880
9fb199c
 
685b36d
26c24d0
d740763
55f8ad2
 
d885880
55f8ad2
9fb199c
7a7a231
 
8e56b06
4925b03
7a7a231
 
 
 
9fb199c
e75d735
7a7a231
9fb199c
26c24d0
d885880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d740763
d885880
 
 
6a9721c
 
 
13e8626
6a9721c
d885880
ee9c753
 
efd70d5
 
 
 
 
 
 
 
 
 
 
ee9c753
efd70d5
ee9c753
e75d735
7a7a231
685b36d
e75d735
 
 
 
 
 
 
 
3b80656
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
efd70d5
3b80656
 
e75d735
 
 
685b36d
 
 
 
 
 
 
 
 
 
 
ce86dad
9ad9538
 
685b36d
 
 
 
 
 
 
 
 
ce86dad
685b36d
ce86dad
 
 
 
685b36d
 
 
 
 
 
0f7bfc3
685b36d
 
 
e75d735
 
3b80656
e75d735
 
 
 
 
 
 
 
 
 
 
 
 
9fb199c
e75d735
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
from typing import TypedDict, List, Dict, Any, Optional, Union

from ddgs import DDGS
from dotenv import load_dotenv
from langchain_core import tools
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFacePipeline
from langgraph.graph import StateGraph, START, END


load_dotenv()

# Base Hugging Face LLM used by the chat wrapper
base_llm = HuggingFaceEndpoint(
    repo_id="deepseek-ai/DeepSeek-R1-0528",
    # deepseek-ai/DeepSeek-OCR:novita
    task="text-generation",
    temperature=0.0,
    huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN"),
)

# Chat model that works with LangGraph
model = ChatHuggingFace(llm=base_llm)

@tool
def web_search(keywords: str, max_results:int = 5) -> str:
    """
    Uses duckduckgo to search the web

    Use cases:
     - Identify personal information
     - Information search
     - Finding organisation information
     - Obtain the latest news

     Args:
        keywords: keywords used to search the web
        max_results: number of results to show after searching the web, defaults to 5

    Returns:
        Search result (Header + body + url)
    """
    with DDGS() as ddgs:
        # Perform a text search
        output = ""
        results = ddgs.text(keywords, max_results = max_results)
        for result in results:
            output += f"Results: {result['title']}\n{result['body']}\n{result['href']}\n\n"
        return(output)

@tool
def get_image_file(task_id):
    """
    Get the image file from the question
    Use cases:
     - Extract Image from the question

     Args:
        task_id: the task_id of the question

    Returns:
        Image file result
    """
    pass
    return ''


class AgentState(TypedDict):
    messages: List[Union[HumanMessage, AIMessage]]


def read_message(state: AgentState) -> AgentState:
    messages = state["messages"]
    print(f"Processing question: {messages[-1].content if messages else ''}")
    # Just pass the messages through to the next node
    return {"messages": messages}

# def tool_message(state: AgentState) -> AgentState:
#     messages = state["messages"]
#     prompt = f"""
#     You are a GAIA question answering expert. 
#     Your task is to decide whether to use a tool or not.
#     If you need to use a tool, answer ONLY:
#         CALL_TOOL: <your tool name>
#     If you do not need to use a tool, answer ONLY:
#         NO_TOOL
#     Here is the question:
#     {messages}
#     """
#     return {"messages": messages}
#     response = model_with_tools.invoke(prompt)
#     return {"messages": messages + [response]}

# Augment the LLM with tools
tools = [web_search,get_image_file]
tools_by_name = {tool.name: tool for tool in tools}
model_with_tools = model.bind_tools(tools)

def answer_message(state: AgentState) -> AgentState:
    messages = state["messages"]
    prompt = [SystemMessage(f"""
    You are a GAIA question answering expert. 
    Your task is to provide an answer to a question. 
    Think carefully before answering the question. 
    Do not include any thought process before answering the question, and only response exactly what was being asked of you.
    If you are not able to provide an answer, use tools or state the limitation that you're facing instead. 

    Example question: How many hours are there in a day?
    Response: 24    
    """)]
    messages = prompt + messages
    ai_msg = model_with_tools.invoke(messages)
    messages.append(ai_msg)

    # Step 2: Execute tools and collect results
    for tool_call in ai_msg.tool_calls:
        # Execute the tool with the generated arguments
        name = tool_call['name']
        args = tool_call['args']
        tool = tools_by_name[name]
        tool_result = tool.invoke(args)
        messages.append(tool_result)
    
    final_instruction = HumanMessage(
    content=(
            "Using the tool results above, provide the FINAL answer now. "
            "Do not call any tools. Respond with only the answer."
        )
    )
    messages.append(final_instruction)

    final_response = model_with_tools.invoke(messages)

    # final_response = model_with_tools.invoke(messages)
    print(f"Final response: {final_response}")
    final_response = final_response.content.split('</think>')[1].trim()

    # Append the model's answer to the messages list
    return {"messages": [final_response]}



def build_graph():
    agent_graph = StateGraph(AgentState)

    # Add nodes
    agent_graph.add_node("read_message", read_message)
    agent_graph.add_node("answer_message", answer_message)

    # Add edges
    agent_graph.add_edge(START, "read_message")
    agent_graph.add_edge("read_message", "answer_message")

    # Final edge
    agent_graph.add_edge("answer_message", END)

    # Compile and return the executable graph for use in app.py
    compiled_graph = agent_graph.compile()
    return compiled_graph