Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -23,7 +23,7 @@ configure(api_key=os.getenv("GOOGLE_API_KEY"))
|
|
| 23 |
#logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
# --- Model Configuration ---
|
| 26 |
-
GEMINI_MODEL_NAME = "gemini/gemini-
|
| 27 |
OPENAI_MODEL_NAME = "openai/gpt-4o"
|
| 28 |
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
|
| 29 |
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
|
|
@@ -123,6 +123,46 @@ class WikiContentFetcher(Tool):
|
|
| 123 |
except wiki.exceptions.PageError:
|
| 124 |
return f"'{page_title}' not found."
|
| 125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
# --- Basic Agent Definition ---
|
| 127 |
class BasicAgent:
|
| 128 |
def __init__(self, provider="deepseek"):
|
|
@@ -137,6 +177,7 @@ class BasicAgent:
|
|
| 137 |
MathSolver(),
|
| 138 |
RiddleSolver(),
|
| 139 |
TextTransformer(),
|
|
|
|
| 140 |
]
|
| 141 |
self.agent = ToolCallingAgent(
|
| 142 |
model=model,
|
|
@@ -146,21 +187,24 @@ class BasicAgent:
|
|
| 146 |
)
|
| 147 |
self.agent.system_prompt = (
|
| 148 |
"""
|
| 149 |
-
You are a
|
| 150 |
-
|
| 151 |
-
If
|
| 152 |
-
If
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
If
|
| 158 |
-
If
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
|
|
|
|
|
|
| 162 |
|
| 163 |
-
|
|
|
|
| 164 |
"""
|
| 165 |
)
|
| 166 |
|
|
@@ -183,7 +227,40 @@ class BasicAgent:
|
|
| 183 |
final_str = result["final_answer"].strip()
|
| 184 |
else:
|
| 185 |
final_str = str(result).strip()
|
| 186 |
-
|
| 187 |
-
return final_str
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
#logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
# --- Model Configuration ---
|
| 26 |
+
GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
|
| 27 |
OPENAI_MODEL_NAME = "openai/gpt-4o"
|
| 28 |
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
|
| 29 |
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
|
|
|
|
| 123 |
except wiki.exceptions.PageError:
|
| 124 |
return f"'{page_title}' not found."
|
| 125 |
|
| 126 |
+
class FileAttachmentQueryTool(Tool):
|
| 127 |
+
name = "run_query_with_file"
|
| 128 |
+
description = """
|
| 129 |
+
Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
|
| 130 |
+
This assumes the file is 20MB or less.
|
| 131 |
+
"""
|
| 132 |
+
inputs = {
|
| 133 |
+
"task_id": {
|
| 134 |
+
"type": "string",
|
| 135 |
+
"description": "A unique identifier for the task related to this file, used to download it."
|
| 136 |
+
},
|
| 137 |
+
"mime_type": {
|
| 138 |
+
"type": "string",
|
| 139 |
+
"nullable": True,
|
| 140 |
+
"description": "The MIME type of the file, or the best guess if unknown."
|
| 141 |
+
},
|
| 142 |
+
"user_query": {
|
| 143 |
+
"type": "string",
|
| 144 |
+
"description": "The question to answer about the file."
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
output_type = "string"
|
| 148 |
+
|
| 149 |
+
def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
|
| 150 |
+
file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
|
| 151 |
+
file_response = requests.get(file_url)
|
| 152 |
+
if file_response.status_code != 200:
|
| 153 |
+
return f"Failed to download file: {file_response.status_code} - {file_response.text}"
|
| 154 |
+
file_data = file_response.content
|
| 155 |
+
mime_type = mime_type or file_response.headers.get('Content-Type', 'application/octet-stream')
|
| 156 |
+
|
| 157 |
+
from google.generativeai import GenerativeModel
|
| 158 |
+
model = GenerativeModel(self.model_name)
|
| 159 |
+
response = model.generate_content([
|
| 160 |
+
types.Part.from_bytes(data=file_data, mime_type=mime_type),
|
| 161 |
+
user_query
|
| 162 |
+
])
|
| 163 |
+
|
| 164 |
+
return response.text
|
| 165 |
+
|
| 166 |
# --- Basic Agent Definition ---
|
| 167 |
class BasicAgent:
|
| 168 |
def __init__(self, provider="deepseek"):
|
|
|
|
| 177 |
MathSolver(),
|
| 178 |
RiddleSolver(),
|
| 179 |
TextTransformer(),
|
| 180 |
+
FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
|
| 181 |
]
|
| 182 |
self.agent = ToolCallingAgent(
|
| 183 |
model=model,
|
|
|
|
| 187 |
)
|
| 188 |
self.agent.system_prompt = (
|
| 189 |
"""
|
| 190 |
+
You are a GAIA benchmark AI assistant. Your sole purpose is to provide exact, minimal answers in the format 'FINAL ANSWER: [ANSWER]' with no additional text, explanations, or comments.
|
| 191 |
+
|
| 192 |
+
- If the answer is a number, use numerals (e.g., '42', not 'forty-two'), without commas or units (e.g., no '$', '%') unless explicitly requested.
|
| 193 |
+
- If the answer is a string, use no articles ('a', 'the'), no abbreviations (e.g., 'New York', not 'NY'), and write digits as text (e.g., 'one', not '1') unless specified.
|
| 194 |
+
- For comma-separated lists, apply the above rules to each element based on whether it's a number or string.
|
| 195 |
+
- Answer as literally as possible, making minimal assumptions and adhering to the question's narrowest interpretation.
|
| 196 |
+
- For videos, analyze the entire content but extract only the precise answer to the query, ignoring irrelevant details.
|
| 197 |
+
- For Wikipedia or search tools, distill results to the minimal correct answer, ignoring extraneous content.
|
| 198 |
+
- If proving something, compute step-by-step internally but output only the final result in the required format.
|
| 199 |
+
- If tool outputs are verbose, extract only the essential answer that satisfies the question.
|
| 200 |
+
- Under no circumstances include explanations, intermediate steps, or text outside the 'FINAL ANSWER: [ANSWER]' format.
|
| 201 |
+
|
| 202 |
+
Example:
|
| 203 |
+
Question: What is 2 + 2?
|
| 204 |
+
Response: FINAL ANSWER: 4
|
| 205 |
|
| 206 |
+
Your response must always be:
|
| 207 |
+
FINAL ANSWER: [ANSWER]
|
| 208 |
"""
|
| 209 |
)
|
| 210 |
|
|
|
|
| 227 |
final_str = result["final_answer"].strip()
|
| 228 |
else:
|
| 229 |
final_str = str(result).strip()
|
|
|
|
|
|
|
| 230 |
|
| 231 |
+
return f"FINAL ANSWER: {final_str}"
|
| 232 |
+
|
| 233 |
+
def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3, show_steps: bool = True):
|
| 234 |
+
df = pd.read_csv(csv_path)
|
| 235 |
+
if not {"question", "answer"}.issubset(df.columns):
|
| 236 |
+
print("CSV must contain 'question' and 'answer' columns.")
|
| 237 |
+
print("Found columns:", df.columns.tolist())
|
| 238 |
+
return
|
| 239 |
+
samples = df.sample(n=sample_size)
|
| 240 |
+
for _, row in samples.iterrows():
|
| 241 |
+
question = row["question"].strip()
|
| 242 |
+
expected = f"FINAL ANSWER: {str(row['answer']).strip()}"
|
| 243 |
+
result = self(question).strip()
|
| 244 |
+
if show_steps:
|
| 245 |
+
print("---")
|
| 246 |
+
print("Question:", question)
|
| 247 |
+
print("Expected:", expected)
|
| 248 |
+
print("Agent:", result)
|
| 249 |
+
print("Correct:", expected == result)
|
| 250 |
+
else:
|
| 251 |
+
print(f"Q: {question}\nE: {expected}\nA: {result}\n✓: {expected == result}\n")
|
| 252 |
+
|
| 253 |
+
if __name__ == "__main__":
|
| 254 |
+
args = sys.argv[1:]
|
| 255 |
+
if not args or args[0] in {"-h", "--help"}:
|
| 256 |
+
print("Usage: python agent.py [question | dev]")
|
| 257 |
+
print(" - Provide a question to get a GAIA-style answer.")
|
| 258 |
+
print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
|
| 259 |
+
sys.exit(0)
|
| 260 |
|
| 261 |
+
q = " ".join(args)
|
| 262 |
+
agent = BasicAgent()
|
| 263 |
+
if q == "dev":
|
| 264 |
+
agent.evaluate_random_questions()
|
| 265 |
+
else:
|
| 266 |
+
print(agent(q))
|