File size: 3,380 Bytes
9676611
a99da62
26a1143
ee4ab6d
728aee3
a99da62
9676611
ee4ab6d
 
a99da62
ee4ab6d
 
 
a99da62
 
 
 
b351c59
a99da62
ace61ef
ee4ab6d
 
728aee3
 
 
 
 
707829f
 
 
 
 
273a786
13b5d5f
728aee3
1fd978d
 
13b5d5f
707829f
273a786
d712a83
b351c59
26a1143
a99da62
728aee3
a99da62
 
 
 
728aee3
a99da62
 
 
 
 
728aee3
a99da62
 
 
 
 
b351c59
a99da62
 
 
 
 
 
 
 
 
 
26a1143
728aee3
26a1143
728aee3
 
9676611
 
 
 
 
 
 
728aee3
9676611
 
76b7238
728aee3
76b7238
1fd978d
728aee3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
from pathlib import Path
from typing import Annotated, Optional, TypedDict, Union
from uuid import uuid4

import requests
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_google_genai import ChatGoogleGenerativeAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langgraph.prebuilt import tools_condition
from smolagents import CodeAgent, HfApiModel

# Import our custom tools from their modules
from tools import basic_tools

# Model, rate-limit and file-download settings. Each value can be overridden
# through the environment variable named in its os.environ.get call.
RPM = int(os.environ.get("AGENT_MODEL_RPM", 8))
# NOTE(review): TPM is not referenced in the visible code; confirm it is used elsewhere.
TPM = int(os.environ.get("AGENT_MODEL_TPM", 200000))
FILES_ENDPOINT = os.environ.get(
    "FILES_ENDPOINT", "https://agents-course-unit4-scoring.hf.space"
)
TARGET_FILES_DIR = os.environ.get("TARGET_FILES_DIR", "/tmp/task_file")

# A non-positive RPM would give the limiter a refill rate of zero tokens per
# second, so requests would wait indefinitely; fail fast at import instead.
if RPM <= 0:
    raise ValueError(f"AGENT_MODEL_RPM must be a positive integer, got {RPM}")

limiter = InMemoryRateLimiter(
    requests_per_second=(RPM / 60),  # spread RPM evenly over a minute
    # NOTE(review): this poll interval grows with RPM (RPM=8 -> ~0.11 s);
    # confirm that scaling is intended rather than a fraction of 60 / RPM.
    check_every_n_seconds=(RPM / 70),
    max_bucket_size=RPM,  # Controls the maximum burst size.
)

# Google's chat interface (e.g. set AGENT_MODEL=gemini-2.0-flash-lite).
chat = ChatGoogleGenerativeAI(
    model=os.environ.get("AGENT_MODEL", "gemini-2.5-flash-preview-04-17"),
    # Environment values are strings; convert explicitly so an override has
    # the same type as the numeric default.
    temperature=float(os.environ.get("AGENT_MODEL_TEMP", 0.25)),
    max_retries=int(os.environ.get("AGENT_MODEL_RETRIES", 2)),
    verbose=True,
    rate_limiter=limiter,
)

# Chat model that may answer with calls to the tools in basic_tools.
chat_with_tools = chat.bind_tools(basic_tools)
# In-memory checkpoint store; not referenced in the visible code —
# presumably passed to builder.compile(...) elsewhere.
memory = MemorySaver()


# Generate the AgentState and Agent graph
# Message list annotated with LangGraph's ``add_messages`` reducer, so the
# messages a node returns are merged into the existing list rather than
# replacing it.
_MessageHistory = Annotated[list[AnyMessage], add_messages]


class AgentState(TypedDict):
    """State passed between graph nodes: the running conversation."""

    messages: _MessageHistory


def assistant(state: AgentState):
    """Graph node: ask the tool-bound chat model for the next message.

    The whole message history held in ``state`` is sent to the model. Its
    reply is wrapped in a one-element list under ``"messages"`` so the
    ``add_messages`` reducer merges it into the state.
    """
    history = state["messages"]
    reply = chat_with_tools.invoke(history)
    return {"messages": [reply]}


## The graph
def _build_graph():
    """Wire the assistant/tools loop into an (uncompiled) StateGraph."""
    graph = StateGraph(AgentState)

    # Nodes do the work: the model call and the tool executor.
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(basic_tools))

    # Edges decide control flow. Every run starts at the assistant;
    # tools_condition sends a reply that requests a tool to "tools" and
    # otherwise lets the run finish. Tool results go back to the assistant.
    graph.add_edge(START, "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")
    return graph


builder = _build_graph()


def create_config():
    """Build a LangGraph run config carrying a fresh random thread id.

    Returns:
        ``{"configurable": {"thread_id": <uuid4 string>}}``. A new id is
        generated on every call, so separate runs never share a thread_id.
    """
    thread_id = str(uuid4())
    configurable = {"thread_id": thread_id}
    return {"configurable": configurable}


def get_system_prompt(prompt_file: Optional[Union[str, Path]] = None):
    """Load the agent's system prompt from a UTF-8 text file.

    Args:
        prompt_file: Location of the prompt file, as a ``Path`` or a string.
            Defaults to ``system_prompt.txt``, which is resolved relative to
            the current working directory.

    Returns:
        A ``SystemMessage`` whose content is the full text of the file.

    Raises:
        FileNotFoundError: If the prompt file does not exist.
    """
    # Path() accepts both str and Path, so callers may pass either form.
    path = Path("system_prompt.txt") if prompt_file is None else Path(prompt_file)
    system_prompt = path.read_text(encoding="utf-8")
    return SystemMessage(content=system_prompt)


def insert_file_into_query(
    query: str, file_name: Optional[Union[str, Path]] = ""
):
    """Append a reference to an accompanying file to a user query.

    Args:
        query: The question text.
        file_name: Path of the file attached to the question, as a string or
            ``Path``. Empty or ``None`` means there is no attachment.

    Returns:
        ``"<query> - Adjacent files path > <file_name>"``, or ``query``
        unchanged when there is no file.
    """
    # Without a file the suffix would point the model at an empty path.
    if not file_name:
        return query
    return f"{query} - Adjacent files path > {file_name}"


def download_requested_file(
    task_id: str,
    question_file: Optional[str],
    endpoint: str = FILES_ENDPOINT,
    target_dir: str = TARGET_FILES_DIR,
) -> Optional[Path]:
    """Download the file attached to a task and save it in ``target_dir``.

    Args:
        task_id: Identifier used to build the ``{endpoint}/files/{task_id}`` URL.
        question_file: Name to save the download under. Empty or ``None``
            means the task has no attachment and nothing is downloaded.
        endpoint: Base URL of the files service.
        target_dir: Directory to store the file in; created if missing.

    Returns:
        The path of the written file, or ``None`` when there is no file.

    Raises:
        ValueError: If ``question_file`` has no usable final path component.
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection errors or timeouts.
    """
    if not question_file:
        return None

    # The name comes from the remote task; keep only its final component so
    # values like "../x" or "/etc/x" cannot write outside target_dir.
    safe_name = Path(question_file).name
    if safe_name in ("", ".", ".."):
        raise ValueError(f"Invalid question file name: {question_file!r}")

    response = requests.get(
        f"{endpoint}/files/{task_id}", timeout=30, allow_redirects=True
    )
    # Fail loudly instead of saving an HTTP error page as the task file.
    response.raise_for_status()

    target_path = Path(target_dir)
    # exist_ok avoids a race between checking for and creating the directory.
    target_path.mkdir(parents=True, exist_ok=True)
    file_path = target_path / safe_name
    file_path.write_bytes(response.content)
    return file_path