|
|
from pathlib import Path |
|
|
from typing import TypedDict, Annotated |
|
|
from uuid import uuid4 |
|
|
import os |
|
|
import requests |
|
|
|
|
|
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage |
|
|
from langchain_core.rate_limiters import InMemoryRateLimiter |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from langgraph.checkpoint.memory import MemorySaver |
|
|
from langgraph.graph import START, StateGraph |
|
|
from langgraph.graph.message import add_messages |
|
|
from langgraph.prebuilt import ToolNode |
|
|
from langgraph.prebuilt import tools_condition |
|
|
from smolagents import CodeAgent, HfApiModel |
|
|
|
|
|
|
|
|
from tools import basic_tools |
|
|
|
|
|
|
|
|
# Rate-limit configuration for the Gemini model (requests / tokens per minute).
RPM = int(os.environ.get("AGENT_MODEL_RPM", 8))
TPM = int(os.environ.get("AGENT_MODEL_TPM", 200000))

# Scoring service that serves per-task attachment files.
FILES_ENDPOINT = os.environ.get(
    "FILES_ENDPOINT", "https://agents-course-unit4-scoring.hf.space"
)
# Local directory where downloaded task files are stored.
TARGET_FILES_DIR = os.environ.get("TARGET_FILES_DIR", "/tmp/task_file")

# Client-side token-bucket limiter so calls stay under the provider's RPM quota.
limiter = InMemoryRateLimiter(
    requests_per_second=(RPM / 60),
    check_every_n_seconds=(RPM / 70),
    max_bucket_size=RPM,
)

chat = ChatGoogleGenerativeAI(
    model=os.environ.get("AGENT_MODEL", "gemini-2.5-flash-preview-04-17"),
    # Cast explicitly: os.environ.get returns a *string* whenever the variable
    # is set, which would otherwise pass a non-numeric temperature/retry count.
    temperature=float(os.environ.get("AGENT_MODEL_TEMP", 0.25)),
    max_retries=int(os.environ.get("AGENT_MODEL_RETRIES", 2)),
    verbose=True,
    rate_limiter=limiter,
)

# Model wired to the toolbelt; the graph's assistant node invokes this.
chat_with_tools = chat.bind_tools(basic_tools)

# In-memory checkpointer holding per-thread conversation state.
memory = MemorySaver()
|
|
|
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Graph state: the running message history of the conversation."""

    # `add_messages` is a reducer: new messages returned by a node are
    # appended/merged into the history instead of replacing the list.
    messages: Annotated[list[AnyMessage], add_messages]
|
|
|
|
|
|
|
|
def assistant(state: AgentState):
    """LLM node: run the tool-bound model on the history and emit its reply."""
    reply = chat_with_tools.invoke(state["messages"])
    return {"messages": [reply]}
|
|
|
|
|
|
|
|
|
|
|
# Assemble the agent graph: an assistant <-> tools loop over AgentState.
builder = StateGraph(AgentState)

# "assistant" runs the tool-bound model; "tools" executes any tool calls it emits.
builder.add_node("assistant", assistant)
builder.add_node("tools", ToolNode(basic_tools))

builder.add_edge(START, "assistant")
# tools_condition routes to "tools" when the last AI message requested tool
# calls, otherwise to END.
builder.add_conditional_edges(
    "assistant",
    tools_condition,
)
# Tool results flow back to the assistant for the next reasoning turn.
builder.add_edge("tools", "assistant")
|
|
|
|
|
|
|
|
def create_config():
    """Build a per-conversation runnable config with a fresh thread id."""
    thread_id = str(uuid4())
    return {"configurable": {"thread_id": thread_id}}
|
|
|
|
|
|
|
|
def get_system_prompt(prompt_file: Path | None = None):
    """Load the system prompt from disk and wrap it as a SystemMessage.

    Args:
        prompt_file: Path to the prompt text file; defaults to
            ``system_prompt.txt`` in the current working directory.

    Returns:
        A ``SystemMessage`` whose content is the file's full text.

    Raises:
        OSError: If the prompt file cannot be opened or read.
    """
    if prompt_file is None:
        prompt_file = Path("system_prompt.txt")

    # read_text handles open/close and decoding in one call.
    system_prompt = prompt_file.read_text(encoding="utf-8")

    return SystemMessage(content=system_prompt)
|
|
|
|
|
|
|
|
def insert_file_into_query(query: str, file_name: str = ""):
    """Append the companion file's path hint to the user query text."""
    suffix = f" - Adjacent files path > {file_name}"
    return query + suffix
|
|
|
|
|
|
|
|
def download_requested_file(
    task_id: str,
    question_file: str,
    endpoint: str = FILES_ENDPOINT,
    target_dir: str = TARGET_FILES_DIR,
):
    """Download the attachment for a task into ``target_dir``.

    Args:
        task_id: Task identifier used by the scoring service's file route.
        question_file: Local file name to save under; empty means the task
            has no attachment.
        endpoint: Base URL of the scoring service.
        target_dir: Directory that receives the downloaded file.

    Returns:
        The local ``Path`` of the saved file, or ``None`` when
        ``question_file`` is empty.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection problems or timeout.
    """
    if not question_file:
        return None

    target_path = Path(target_dir)
    # exist_ok avoids a race between an existence check and the mkdir.
    target_path.mkdir(parents=True, exist_ok=True)

    file_path = target_path / question_file

    response = requests.get(
        f"{endpoint}/files/{task_id}", timeout=30, allow_redirects=True
    )
    # Fail loudly instead of silently saving an HTML error page as the file.
    response.raise_for_status()

    with file_path.open("wb") as file_:
        file_.write(response.content)
    return file_path
|
|
|