FD900 commited on
Commit
6ece3a3
·
verified ·
1 Parent(s): 9b4ce5f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +35 -21
agent.py CHANGED
@@ -1,30 +1,44 @@
1
  import os
 
 
2
  import requests
3
- import yaml
4
- from tools import web_search, wikipedia_search, visit_webpage, final_answer
5
 
6
- with open("prompts.yaml", "r") as f:
7
- prompts = yaml.safe_load(f)
8
-
9
- SYSTEM_PROMPT = prompts["system_prompt"]
10
 
11
  class GaiaAgent:
12
  def __init__(self):
13
- self.system_prompt = SYSTEM_PROMPT
14
- self.endpoint_url = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
15
  self.headers = {
16
- "Authorization": f"Bearer {os.environ['HF_TOKEN']}",
17
- "Content-Type": "application/json"
 
 
 
 
 
 
 
 
 
 
18
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def __call__(self, task):
21
- prompt = self.system_prompt + "\nTask:\n" + task
22
- payload = {"inputs": prompt}
23
- try:
24
- response = requests.post(self.endpoint_url, headers=self.headers, json=payload)
25
- output = response.json()
26
- if isinstance(output, list) and "generated_text" in output[0]:
27
- return output[0]["generated_text"]
28
- return output
29
- except Exception as e:
30
- return f"AGENT ERROR: {str(e)}"
 
1
  import os
2
+ import json
3
+ from typing import Dict, List
4
  import requests
 
 
5
 
6
+ from gaia_benchmark.questions import load_questions
7
+ from gaia_benchmark.run import run_and_submit_all
 
 
8
 
9
  class GaiaAgent:
10
  def __init__(self):
11
+ self.api_url = os.environ["HF_MISTRAL_ENDPOINT"] # Your Mistral endpoint
12
+ self.api_key = os.environ["HF_TOKEN"] # Hugging Face token
13
  self.headers = {
14
+ "Authorization": f"Bearer {self.api_key}",
15
+ "Content-Type": "application/json",
16
+ }
17
+
18
+ def generate(self, prompt: str, stop: List[str] = []) -> str:
19
+ payload = {
20
+ "inputs": prompt,
21
+ "parameters": {
22
+ "temperature": 0.0,
23
+ "max_new_tokens": 1024,
24
+ "stop": stop,
25
+ }
26
  }
27
+ response = requests.post(self.api_url, headers=self.headers, json=payload)
28
+ response.raise_for_status()
29
+ output = response.json()
30
+ if isinstance(output, dict) and "generated_text" in output:
31
+ return output["generated_text"]
32
+ if isinstance(output, list):
33
+ return output[0]["generated_text"]
34
+ return str(output)
35
+
36
+ def answer_question(self, question: Dict) -> str:
37
+ question_text = question["question"]
38
+ prompt = f"""You are a helpful agent answering a science question.
39
+ Question: {question_text}
40
+ Answer:"""
41
+ return self.generate(prompt).strip()
42
 
43
+ def run(self):
44
+ run_and_submit_all(self.answer_question)