FD900 commited on
Commit
713b432
·
verified ·
1 Parent(s): 713e53f

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +31 -28
agent.py CHANGED
@@ -1,39 +1,42 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
2
 
3
- # Load FLAN-T5 base model
4
- model_name = "google/flan-t5-base"
5
- tokenizer = AutoTokenizer.from_pretrained(model_name)
6
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
7
- generator = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
8
 
9
- # GAIA system prompt
10
  system_prompt = (
11
- "You are a general AI assistant. I will ask you a question. Report your thoughts, "
12
- "and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. "
13
- "YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of "
14
- "numbers and/or strings. If you are asked for a number, don't use comma to write your number "
15
- "neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, "
16
- "don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. "
17
- "If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list "
18
- "is a number or a string.\n"
19
  )
20
 
 
 
 
 
 
21
  class BasicAgent:
22
  def __init__(self):
23
- print("Flan-T5 GAIA agent initialized.")
24
 
25
  def __call__(self, question: str) -> str:
26
- prompt = system_prompt + "\nQuestion: " + question
 
 
 
 
 
 
 
 
27
  try:
28
- result = generator(prompt, max_length=256, do_sample=False)[0]['generated_text']
 
 
 
 
 
 
 
 
29
  except Exception as e:
30
- return f"ERROR: {e}"
31
-
32
- # Extract FINAL ANSWER
33
- final_answer = "None"
34
- if "FINAL ANSWER:" in result:
35
- final_answer = result.split("FINAL ANSWER:")[-1].strip()
36
- else:
37
- final_answer = result.strip()
38
-
39
- return final_answer
 
1
+ import os
2
+ import requests
3
 
4
+ API_URL = "https://api-inference.huggingface.co/models/google/flan-t5-base"
5
+ HF_TOKEN = os.getenv("HF_TOKEN")
 
 
 
6
 
 
7
  system_prompt = (
8
+ "You are a helpful assistant participating in the GAIA benchmark. "
9
+ "Always return direct, factual answers with no explanation. Output only the final answer."
 
 
 
 
 
 
10
  )
11
 
12
+ headers = {
13
+ "Authorization": f"Bearer {HF_TOKEN}",
14
+ "Content-Type": "application/json"
15
+ }
16
+
17
  class BasicAgent:
18
  def __init__(self):
19
+ print("Flan-T5 Agent initialized using Hugging Face API")
20
 
21
  def __call__(self, question: str) -> str:
22
+ prompt = f"{system_prompt}\n\nQuestion:\n{question}\n\nAnswer:"
23
+ payload = {
24
+ "inputs": prompt,
25
+ "parameters": {
26
+ "temperature": 0.0,
27
+ "max_new_tokens": 128
28
+ }
29
+ }
30
+
31
  try:
32
+ response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
33
+ response.raise_for_status()
34
+ data = response.json()
35
+ if isinstance(data, list) and "generated_text" in data[0]:
36
+ return data[0]["generated_text"].strip()
37
+ elif "generated_text" in data[0]:
38
+ return data[0]["generated_text"].strip()
39
+ else:
40
+ return str(data)
41
  except Exception as e:
42
+ return f"AGENT ERROR: {e}"