File size: 4,932 Bytes
e1dc6ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc7c681
1ccd98f
e1dc6ad
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import time
from dotenv import load_dotenv
from typing import TypedDict, Annotated, Optional

from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages

from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain_google_genai import ChatGoogleGenerativeAI

from tools import *

# Load variables from a .env file into the process environment at import time
# (GEMINI_API_KEY is read from the environment via os.getenv further down).
load_dotenv()
# NOTE(review): presumably the base URL of the course scoring service; it is
# not referenced anywhere else in the visible part of this file.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph.

    Used as the state schema of the ``StateGraph`` built in this file.
    """
    # Optional file reference for the current task. Nothing in the visible
    # code sets this to anything other than None.
    input_file: Optional[str]
    # Conversation history. The ``add_messages`` annotation is LangGraph's
    # message reducer: messages returned by a node are merged into this list
    # instead of replacing it.
    messages: Annotated[list[AnyMessage], add_messages]


class GEMINI_AGENT:
    """ReAct-style LangGraph agent backed by a Gemini chat model with tools.

    The compiled graph alternates between an ``assistant`` node (the LLM with
    the tools bound) and a ``tools`` node until the model stops requesting
    tool calls. ``run`` wraps one graph invocation with retries and returns
    the text that follows the ``FINAL ANSWER:`` marker in the last message.
    """

    # Retry / back-off settings used by ``run``. They are class attributes so
    # they can be overridden on a subclass or an instance (e.g. in tests)
    # without changing the method signature.
    MAX_RETRIES = 5
    BASE_SLEEP_SECONDS = 1
    # Pause after every successful call to stay under the API rate limit.
    RATE_LIMIT_SLEEP_SECONDS = 60
    # Marker the system prompt asks the model to put in front of its answer.
    FINAL_ANSWER_MARKER = "FINAL ANSWER:"

    def __init__(self):
        """Create the LLM client, bind the tools and compile the graph.

        The API key is read from the ``GEMINI_API_KEY`` environment variable.
        """
        self.llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-lite",
            temperature=0,
            max_tokens=1024,
            google_api_key=os.getenv("GEMINI_API_KEY"),
        )

        # Tool callables come from the project's ``tools`` module.
        self.tools = [
            duckduck_websearch,
            serper_websearch,
            visit_webpage,
            wiki_search,
            youtube_viewer,
            text_splitter,
            read_file,
            excel_read,
            csv_read,
            mp3_listen,
            image_caption,
            run_python,
            multiply,
            add,
            subtract,
            divide,
        ]

        self.llm_with_tools = self.llm.bind_tools(self.tools)
        self.app = self._graph_compile()

    def _graph_compile(self):
        """Build and compile the assistant <-> tools graph.

        Returns:
            The compiled graph; ``run`` calls its ``invoke`` method.
        """
        builder = StateGraph(AgentState)
        # Nodes: these do the work.
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(self.tools))
        # Edges: these determine how control flows between the nodes.
        builder.add_edge(START, "assistant")
        # After each assistant turn ``tools_condition`` decides whether the
        # tools node has to run or the graph is finished.
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        return builder.compile()

    # The annotation is a string so that defining this class does not require
    # ``AgentState`` to be resolvable at that moment; it is resolved lazily.
    def _assistant(self, state: "AgentState"):
        """Graph node: ask the tool-enabled LLM for the next message.

        Args:
            state: Current graph state; ``state["messages"]`` is the
                conversation so far.

        Returns:
            A state update holding the model's reply and the unchanged
            ``input_file`` (``None`` when the key is absent).
        """
        sys_msg = SystemMessage(
            content=
            """
            You are a helpful assistant tasked with answering questions using a set of tools. When given a question, follow these steps:
            1. Create a clear, step-by-step plan to solve the question.
            2. If a tool is necessary, select the most appropriate tool based on its functionality. If one tool isn't working, use another with similar functionality.
            3. Execute your plan and provide the response in the following format:

            FINAL ANSWER: [YOUR FINAL ANSWER]

            Your final answer should be:

            - A number (without commas or units unless explicitly requested),
            - A short string (avoid articles, abbreviations, and use plain text for digits unless otherwise specified),
            - A comma-separated list (apply the formatting rules above for each element, with exactly one space after each comma).

            Ensure that your answer is concise and follows the task instructions strictly. If the answer is more complex, break it down in a way that follows the format.
            Begin your response with "FINAL ANSWER: " followed by the answer, and nothing else.
            """
        )

        reply = self.llm_with_tools.invoke([sys_msg] + state["messages"])
        return {
            "messages": [reply],
            # ``.get`` keeps the node usable when the caller omitted the key.
            "input_file": state.get("input_file"),
        }

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content into a plain string.

        LangChain message content is either a string or a list whose items
        are strings or dicts; for dict items only a string ``"text"`` entry
        is used, everything else is skipped.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "".join(parts)
        return str(content)

    def extract_after_final_answer(self, text) -> str:
        """Return what follows the first ``FINAL ANSWER:`` marker in *text*.

        Args:
            text: Message content; a string, or the list-of-parts form.

        Returns:
            The answer with surrounding whitespace removed, or ``""`` when
            the marker is not present.
        """
        flat = self._content_to_text(text)
        # Match the marker without its trailing space so that
        # "FINAL ANSWER:\n42" or "FINAL ANSWER:42" are accepted as well;
        # stripping also drops trailing newlines from the answer.
        _, marker, answer = flat.partition(self.FINAL_ANSWER_MARKER)
        return answer.strip() if marker else ""

    def run(self, task: dict) -> str:
        """Answer one task, retrying the graph invocation on failure.

        Args:
            task: Mapping with ``"task_id"`` and ``"question"``, plus an
                optional ``"file_name"``. When a file name is given, the task
                id is appended to the question text.

        Returns:
            The extracted final answer, or an error message string once all
            attempts have failed.
        """
        task_id = task["task_id"]
        question = task["question"]
        file_name = task.get("file_name")  # a missing key means "no file"
        print(f"Agent received question (first 50 chars): {question[:50]}...")

        if not file_name:
            question_text = question
        else:
            question_text = f'{question} with TASK-ID: {task_id}'
        messages = [HumanMessage(content=question_text)]

        max_retries = self.MAX_RETRIES
        for attempt in range(1, max_retries + 1):
            try:
                response = self.app.invoke({"messages": messages, "input_file": None})
                final_ans = self.extract_after_final_answer(response['messages'][-1].content)
            except Exception as e:  # best-effort boundary: report, never crash
                if attempt == max_retries:
                    return f"Error processing query after {max_retries} attempts: {str(e)}"
                # Linear back-off: wait longer after each failed attempt.
                sleep_time = self.BASE_SLEEP_SECONDS * attempt
                print(str(e))
                print(f"Attempt {attempt} failed. Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)
            else:
                time.sleep(self.RATE_LIMIT_SLEEP_SECONDS)  # avoid rate limit
                return final_ans
        # Only reached when MAX_RETRIES is configured as 0 or less.
        return "This is a default answer."