FD900's picture
Update agent.py
5238d4a verified
raw
history blame
3.48 kB
import os
from typing import Dict, List, Optional

import requests

from run import run_and_submit_all  # Adjust path if needed
class GaiaAgent:
    """GAIA question-answering agent backed by a remote text-generation endpoint.

    Configuration is read from environment variables at construction time:
    ``HF_MISTRAL_ENDPOINT`` (URL the prompt is POSTed to), ``HF_TOKEN`` (sent
    as a Bearer token) and ``LLM_MODEL_ID`` (logged at start-up).
    """

    # Prompt prepended to every question; it asks the model to finish with
    # "FINAL ANSWER: ..." in a constrained format.
    SYSTEM_PROMPT = (
        "You are a general AI assistant. I will ask you a question. "
        "Report your thoughts, and finish your answer with the following "
        "template: FINAL ANSWER: [YOUR FINAL ANSWER]. "
        "YOUR FINAL ANSWER should be a number OR as few words as possible "
        "OR a comma separated list of numbers and/or strings. "
        "If you are asked for a number, don't use comma to write your number "
        "neither use units such as $ or percent sign unless specified otherwise. "
        "If you are asked for a string, don't use articles, neither "
        "abbreviations (e.g. for cities), and write the digits in plain text "
        "unless specified otherwise. "
        "If you are asked for a comma separated list, apply the above rules "
        "depending of whether the element to be put in the list is a number "
        "or a string."
    )

    # Seconds ``requests`` waits on the endpoint; without a timeout a stalled
    # connection would block the whole run indefinitely. Override per instance
    # or subclass if the endpoint is slow to warm up.
    REQUEST_TIMEOUT = 120

    def __init__(self):
        """Read endpoint settings from the environment and build request headers.

        Raises:
            AssertionError: if any required environment variable is unset or empty.
        """
        self.api_url = os.environ.get("HF_MISTRAL_ENDPOINT")
        self.api_key = os.environ.get("HF_TOKEN")
        self.model_id = os.environ.get("LLM_MODEL_ID")
        # Explicit checks rather than ``assert`` so validation still runs under
        # ``python -O``; AssertionError is kept so callers see the same type.
        required = (
            (self.api_url, "❌ HF_MISTRAL_ENDPOINT is missing!"),
            (self.api_key, "❌ HF_TOKEN is missing!"),
            (self.model_id, "❌ LLM_MODEL_ID is missing!"),
        )
        for value, message in required:
            if not value:
                raise AssertionError(message)
        print(f"✅ [INIT] Model ID: {self.model_id}")
        print(f"✅ [INIT] Endpoint: {self.api_url}")
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def generate(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """POST ``prompt`` to the endpoint and return the generated text.

        Args:
            prompt: Full prompt text, sent as ``inputs``.
            stop: Stop sequences forwarded as ``parameters.stop``; ``None``
                sends an empty list.

        Returns:
            The response's ``generated_text`` when present, otherwise the
            stringified JSON response, or ``"ERROR: Model call failed"`` when
            the request fails, returns an HTTP error, or the body is not JSON.
        """
        print("🧠 [GENERATE] Prompt sent to model:")
        print(prompt)
        payload = {
            "inputs": prompt,
            "parameters": {
                # NOTE(review): some text-generation backends reject
                # temperature <= 0 with a validation error — confirm the
                # deployed endpoint accepts 0.0.
                "temperature": 0.0,
                "max_new_tokens": 1024,
                # Fresh list per call (no shared mutable default).
                "stop": list(stop) if stop else [],
            },
        }
        try:
            response = requests.post(
                self.api_url,
                headers=self.headers,
                json=payload,
                timeout=self.REQUEST_TIMEOUT,
            )
            response.raise_for_status()
            # Decoding inside the try: a non-JSON body raises a ValueError
            # subclass and is treated as a failed call, not a crash.
            output = response.json()
        except (requests.RequestException, ValueError) as e:
            # Best-effort: return a sentinel so one bad call doesn't abort a run.
            print(f"❌ [ERROR] Request failed: {e}")
            return "ERROR: Model call failed"
        print(f"✅ [RESPONSE] Raw output: {output}")
        return self._extract_text(output)

    @staticmethod
    def _extract_text(output: object):
        """Return ``generated_text`` from a decoded response, else ``str(output)``.

        Accepts either an object with a ``generated_text`` key or a non-empty
        list whose first element is such an object. Any other shape (empty
        list, list of non-dicts, error object) is returned stringified.
        """
        if isinstance(output, dict) and "generated_text" in output:
            return output["generated_text"]
        if (
            isinstance(output, list)
            and output
            and isinstance(output[0], dict)
            and "generated_text" in output[0]
        ):
            return output[0]["generated_text"]
        return str(output)

    def answer_question(self, question: Dict) -> str:
        """Build the prompt for one task, query the model, and return its output.

        Args:
            question: Task record; the text is taken from the first truthy
                value among keys ``"question"``, ``"Question"``, ``"input"``.

        Returns:
            The model output (or the error sentinel from ``generate``) with
            surrounding whitespace stripped.

        Raises:
            ValueError: if none of those keys holds question text.
        """
        print("🧐 [DEBUG] Raw question object:", question)
        q = question.get("question") or question.get("Question") or question.get("input")
        if not q:
            raise ValueError(f"No question text found in: {question}")
        prompt = f"{self.SYSTEM_PROMPT}\n\nQuestion: {q}\nAnswer:"
        return self.generate(prompt).strip()

    def run(self):
        """Pass this agent to ``run_and_submit_all`` and return its result."""
        print("🚀 [RUN] Starting submission...")
        return run_and_submit_all(self)