bobobert4 commited on
Commit
728aee3
·
1 Parent(s): 9277999

fix: add function to recover task files from endpoint

Browse files
Files changed (2) hide show
  1. agent.py +39 -7
  2. app.py +15 -2
agent.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  from pathlib import Path
3
  from typing import TypedDict, Annotated
4
  from uuid import uuid4
 
5
 
6
  from langgraph.graph.message import add_messages
7
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
@@ -9,6 +10,7 @@ from langgraph.prebuilt import ToolNode
9
  from langgraph.graph import START, StateGraph
10
  from langgraph.checkpoint.memory import MemorySaver
11
  from langgraph.prebuilt import tools_condition
 
12
  # from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
13
  from langchain_google_genai import ChatGoogleGenerativeAI
14
  from langchain_core.rate_limiters import InMemoryRateLimiter
@@ -27,6 +29,11 @@ from tools import basic_tools
27
  # Google's chat interface
28
  RPM = os.environ.get("AGENT_MODEL_RPM", 9)
29
  TPM = os.environ.get("AGENT_MODEL_TPM", 250000)
 
 
 
 
 
30
  limiter = InMemoryRateLimiter(
31
  requests_per_second=(RPM / 60),
32
  check_every_n_seconds=(RPM / 70),
@@ -34,7 +41,7 @@ limiter = InMemoryRateLimiter(
34
  )
35
  chat = ChatGoogleGenerativeAI(
36
  # model="gemini-2.0-flash-lite",
37
- model=os.environ.get("AGENT_MODEL","gemini-2.5-flash-preview-04-17"),
38
  temperature=os.environ.get("AGENT_MODEL_TEMP", 0.25),
39
  max_retries=os.environ.get("AGENT_MODEL_RETRIES", 2),
40
  verbose=True,
@@ -44,15 +51,18 @@ chat = ChatGoogleGenerativeAI(
44
  chat_with_tools = chat.bind_tools(basic_tools)
45
  memory = MemorySaver()
46
 
 
47
  # Generate the AgentState and Agent graph
48
  class AgentState(TypedDict):
49
  messages: Annotated[list[AnyMessage], add_messages]
50
 
 
51
  def assistant(state: AgentState):
52
  return {
53
  "messages": [chat_with_tools.invoke(state["messages"])],
54
  }
55
 
 
56
  ## The graph
57
  builder = StateGraph(AgentState)
58
 
@@ -70,12 +80,10 @@ builder.add_conditional_edges(
70
  )
71
  builder.add_edge("tools", "assistant")
72
 
 
73
  def create_config():
74
- return {
75
- "configurable": {
76
- "thread_id": str(uuid4())
77
- }
78
- }
79
 
80
  def get_system_prompt(prompt_file: Path = None):
81
  if prompt_file is None:
@@ -83,9 +91,33 @@ def get_system_prompt(prompt_file: Path = None):
83
  # load the system prompt from the file
84
  with prompt_file.open("r", encoding="utf-8") as f:
85
  system_prompt = f.read()
86
-
87
  # System message
88
  return SystemMessage(content=system_prompt)
89
 
 
90
  def insert_file_into_query(query: str, file_name: str = ""):
91
  return f"""{query} - Adjacent files path > {file_name}"""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from pathlib import Path
3
  from typing import TypedDict, Annotated
4
  from uuid import uuid4
5
+ import requests
6
 
7
  from langgraph.graph.message import add_messages
8
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
 
10
  from langgraph.graph import START, StateGraph
11
  from langgraph.checkpoint.memory import MemorySaver
12
  from langgraph.prebuilt import tools_condition
13
+
14
  # from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
15
  from langchain_google_genai import ChatGoogleGenerativeAI
16
  from langchain_core.rate_limiters import InMemoryRateLimiter
 
29
  # Google's chat interface
30
  RPM = os.environ.get("AGENT_MODEL_RPM", 9)
31
  TPM = os.environ.get("AGENT_MODEL_TPM", 250000)
32
+ FILES_ENDPOINT = os.environ.get(
33
+ "FILES_ENDPOINT", "https://agents-course-unit4-scoring.hf.space"
34
+ )
35
+ TARGET_FILES_DIR = os.environ.get("TARGET_FILES_DIR", "/tmp/task_file")
36
+
37
  limiter = InMemoryRateLimiter(
38
  requests_per_second=(RPM / 60),
39
  check_every_n_seconds=(RPM / 70),
 
41
  )
42
  chat = ChatGoogleGenerativeAI(
43
  # model="gemini-2.0-flash-lite",
44
+ model=os.environ.get("AGENT_MODEL", "gemini-2.5-flash-preview-04-17"),
45
  temperature=os.environ.get("AGENT_MODEL_TEMP", 0.25),
46
  max_retries=os.environ.get("AGENT_MODEL_RETRIES", 2),
47
  verbose=True,
 
51
  chat_with_tools = chat.bind_tools(basic_tools)
52
  memory = MemorySaver()
53
 
54
+
55
  # Generate the AgentState and Agent graph
56
  class AgentState(TypedDict):
57
  messages: Annotated[list[AnyMessage], add_messages]
58
 
59
+
60
  def assistant(state: AgentState):
61
  return {
62
  "messages": [chat_with_tools.invoke(state["messages"])],
63
  }
64
 
65
+
66
  ## The graph
67
  builder = StateGraph(AgentState)
68
 
 
80
  )
81
  builder.add_edge("tools", "assistant")
82
 
83
+
84
  def create_config():
85
+ return {"configurable": {"thread_id": str(uuid4())}}
86
+
 
 
 
87
 
88
  def get_system_prompt(prompt_file: Path = None):
89
  if prompt_file is None:
 
91
  # load the system prompt from the file
92
  with prompt_file.open("r", encoding="utf-8") as f:
93
  system_prompt = f.read()
94
+
95
  # System message
96
  return SystemMessage(content=system_prompt)
97
 
98
+
99
  def insert_file_into_query(query: str, file_name: str = ""):
100
  return f"""{query} - Adjacent files path > {file_name}"""
101
+
102
+
103
+ def download_requested_file(
104
+ task_id: str,
105
+ question_file: str,
106
+ endpoint: str = FILES_ENDPOINT,
107
+ target_dir: str = TARGET_FILES_DIR,
108
+ ):
109
+ if question_file == "":
110
+ return
111
+
112
+ target_path = Path(target_dir)
113
+ if not target_path.exists():
114
+ target_path.mkdir(parents=True)
115
+ # Create path
116
+ file_path = target_path / question_file
117
+ # Download file
118
+ request = requests.get(
119
+ f"{endpoint}/files/{task_id}", timeout=30, allow_redirects=True
120
+ )
121
+ with file_path.open("wb") as file_:
122
+ file_.write(request.content)
123
+ return file_path
app.py CHANGED
@@ -4,7 +4,7 @@ import requests
4
  import inspect
5
  import pandas as pd
6
  from pprint import pprint
7
- from agent import builder, HumanMessage, memory, create_config, get_system_prompt, insert_file_into_query
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
@@ -98,10 +98,19 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
98
  for item in questions_data:
99
  task_id = item.get("task_id")
100
  question_text = item.get("question")
101
- question_file = item.get("file_name", "")
102
  if not task_id or question_text is None:
103
  print(f"Skipping item with missing task_id or question: {item}")
104
  continue
 
 
 
 
 
 
 
 
 
 
105
  try:
106
  submitted_answer = run_agent(question_text, question_file)
107
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
@@ -109,6 +118,10 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
109
  except Exception as e:
110
  print(f"Error running agent on task {task_id}: {e}")
111
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
 
 
 
 
112
 
113
  if not answers_payload:
114
  print("Agent did not produce any answers to submit.")
 
4
  import inspect
5
  import pandas as pd
6
  from pprint import pprint
7
+ from agent import builder, HumanMessage, memory, create_config, get_system_prompt, insert_file_into_query, download_requested_file
8
 
9
  # (Keep Constants as is)
10
  # --- Constants ---
 
98
  for item in questions_data:
99
  task_id = item.get("task_id")
100
  question_text = item.get("question")
 
101
  if not task_id or question_text is None:
102
  print(f"Skipping item with missing task_id or question: {item}")
103
  continue
104
+
105
+ # Get task file if any
106
+ question_file = item.get("file_name", "")
107
+ try:
108
+ file_path = download_requested_file(task_id, question_file)
109
+ except Exception as err:
110
+ print(f"Could not download file ({question_file}) -> {err}")
111
+ file_path = None
112
+
113
+ question_file = str(file_path) if file_path is not None else ""
114
  try:
115
  submitted_answer = run_agent(question_text, question_file)
116
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
 
118
  except Exception as e:
119
  print(f"Error running agent on task {task_id}: {e}")
120
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
121
+
122
+ # Clean files
123
+ if file_path is not None:
124
+ os.remove(file_path)
125
 
126
  if not answers_payload:
127
  print("Agent did not produce any answers to submit.")