FD900's picture
Update agent.py
6ece3a3 verified
raw
history blame
1.49 kB
import json
import os
from typing import Dict, List, Optional

import requests

from gaia_benchmark.questions import load_questions
from gaia_benchmark.run import run_and_submit_all
class GaiaAgent:
    """Answer GAIA benchmark questions with a Hugging Face text-generation endpoint.

    The endpoint URL and credentials come from the environment:

    * ``HF_MISTRAL_ENDPOINT`` -- URL that :meth:`generate` POSTs prompts to.
    * ``HF_TOKEN`` -- sent as ``Authorization: Bearer <token>``.
    """

    def __init__(self, timeout: float = 120.0):
        """Read the endpoint configuration from the environment.

        Args:
            timeout: Seconds ``requests`` waits for the endpoint on each call.
                Without a timeout a stalled endpoint would block forever.

        Raises:
            KeyError: If ``HF_MISTRAL_ENDPOINT`` or ``HF_TOKEN`` is not set.
        """
        self.api_url = os.environ["HF_MISTRAL_ENDPOINT"]  # Your Mistral endpoint
        self.api_key = os.environ["HF_TOKEN"]  # Hugging Face token
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    @staticmethod
    def _build_payload(prompt: str, stop: Optional[List[str]] = None) -> Dict:
        """Build the JSON body for a text-generation request."""
        return {
            "inputs": prompt,
            "parameters": {
                # Greedy (deterministic) decoding. TGI-backed endpoints reject
                # temperature=0.0 as "not strictly positive", so determinism is
                # requested with do_sample=False instead.
                "do_sample": False,
                "max_new_tokens": 1024,
                # Only the continuation is wanted; some HF text-generation
                # APIs echo the prompt unless told otherwise.
                "return_full_text": False,
                # Copy so the caller's list is never aliased into the payload.
                "stop": list(stop) if stop else [],
            },
        }

    @staticmethod
    def _extract_text(output, prompt: str = "") -> str:
        """Pull ``generated_text`` out of a decoded response.

        Accepts either a dict or a non-empty list whose first element is a
        dict. If the text still starts with *prompt* (the server ignored
        ``return_full_text``), that prefix is removed. Anything else is
        returned as ``str(output)``.
        """
        first = output[0] if isinstance(output, list) and output else output
        if isinstance(first, dict) and "generated_text" in first:
            text = first["generated_text"]
            if prompt and isinstance(text, str) and text.startswith(prompt):
                text = text[len(prompt):]
            return text
        return str(output)

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Send *prompt* to the endpoint and return the generated continuation.

        Args:
            prompt: Text sent as the request's ``inputs``.
            stop: Stop sequences forwarded to the endpoint; ``None`` means none.

        Returns:
            The response's ``generated_text`` with any echoed prompt removed,
            or ``str()`` of the decoded JSON when no ``generated_text`` exists.

        Raises:
            requests.HTTPError: If the endpoint answers with an error status.
            requests.Timeout: If the endpoint does not answer within
                ``self.timeout`` seconds.
        """
        payload = self._build_payload(prompt, stop)
        response = requests.post(
            self.api_url,
            headers=self.headers,
            json=payload,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._extract_text(response.json(), prompt)

    def answer_question(self, question: Dict) -> str:
        """Prompt the model with ``question["question"]`` and return the stripped answer.

        Raises:
            KeyError: If *question* has no ``"question"`` key.
        """
        question_text = question["question"]
        prompt = (
            "You are a helpful agent answering a science question.\n"
            f"Question: {question_text}\n"
            "Answer:"
        )
        return self.generate(prompt).strip()

    def run(self):
        """Hand :meth:`answer_question` to ``run_and_submit_all``.

        ``run_and_submit_all`` presumably calls it once per benchmark question
        and submits the results (its behavior is defined outside this file).
        """
        run_and_submit_all(self.answer_question)