FD900 commited on
Commit
8c0dedc
·
verified ·
1 Parent(s): 70f978a

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +34 -58
agent.py CHANGED
@@ -1,69 +1,45 @@
1
- import os
2
  import json
3
- from smolagents import create_agent_executor, Agent, Task
4
- from langgraph.prebuilt import ToolExecutor
5
- from huggingface_hub import InferenceClient
6
- from tools import *
7
-
8
- # Load Hugging Face endpoint from environment
9
- HF_API_URL = os.getenv("HF_ENDPOINT_URL")
10
- HF_API_TOKEN = os.getenv("HF_TOKEN")
 
 
 
 
11
 
12
- if not HF_API_URL or not HF_API_TOKEN:
13
- raise ValueError("Missing Hugging Face endpoint URL or token.")
 
 
 
 
 
14
 
15
- llm = InferenceClient(
16
- model=HF_API_URL,
17
- token=HF_API_TOKEN
18
  )
19
 
20
- def run_llm(prompt: str) -> str:
21
- response = llm.text_generation(
22
- prompt,
23
- max_new_tokens=512,
24
- do_sample=False,
25
- temperature=0.0,
26
- return_full_text=False,
27
- )
28
- return response.strip()
29
-
30
- tool_list = [
31
- GetAttachmentTool(),
32
- GoogleSearchTool(),
33
- GoogleSiteSearchTool(),
34
- ContentRetrieverTool(),
35
- SpeechRecognitionTool(),
36
- YouTubeVideoTool(),
37
- ClassifierTool(),
38
- ImageToChessBoardFENTool()
39
  ]
40
 
41
- # Create tool executor for LangGraph
42
- tool_executor = ToolExecutor(tool_list)
43
-
44
- # Create agent instance
45
  agent = Agent(
46
- llm=run_llm,
47
- tools=tool_list,
 
48
  )
49
 
50
- agent_executor = create_agent_executor(
51
- agent=agent,
52
- tool_executor=tool_executor,
53
- stream=False,
54
- )
55
-
56
- def load_tasks(metadata_path="metadata.jsonl") -> list[Task]:
57
- with open(metadata_path, "r") as f:
58
- tasks = []
59
- for line in f:
60
- data = json.loads(line)
61
- tasks.append(Task(
62
- task_id=data["question_id"],
63
- input=data["answer"]
64
- ))
65
- return tasks
66
-
67
  def solve_task(task: Task) -> str:
68
- result = agent_executor.invoke(task.input)
69
- return result.get("output", "")
 
 
1
  import json
2
+ from pathlib import Path
3
+ from smolagents import Agent, Task
4
+ from mistral_hf_wrapper import HuggingFaceModel
5
+ from tools import (
6
+ chess_tools,
7
+ classifier_tool,
8
+ content_retriever_tool,
9
+ get_attachments_tool,
10
+ google_search_tools,
11
+ speech_recognition_tool,
12
+ youtube_video_tool,
13
+ )
14
 
15
+ def load_tasks(path: str = "metadata.jsonl") -> list[Task]:
16
+ tasks = []
17
+ with open(path, "r", encoding="utf-8") as f:
18
+ for line in f:
19
+ task_dict = json.loads(line.strip())
20
+ tasks.append(Task(**task_dict))
21
+ return tasks
22
 
23
+ model = HuggingFaceModel(
24
+ endpoint_url="https://api-inference.huggingface.co/models/mistralai/Mistral-7B-Instruct-v0.1"
 
25
  )
26
 
27
+ toolset = [
28
+ chess_tools.ImageToChessBoardFENTool(),
29
+ classifier_tool.ClassifierTool(),
30
+ content_retriever_tool.ContentRetrieverTool(),
31
+ get_attachments_tool.GetAttachmentTool(),
32
+ google_search_tools.GoogleSearchTool(),
33
+ google_search_tools.GoogleSiteSearchTool(),
34
+ speech_recognition_tool.SpeechRecognitionTool(),
35
+ youtube_video_tool.YouTubeVideoTool(),
 
 
 
 
 
 
 
 
 
 
36
  ]
37
 
 
 
 
 
38
  agent = Agent(
39
+ model=model,
40
+ tools=toolset,
41
+ system_message="You are a helpful assistant solving GAIA benchmark tasks. Return only the final answer, nothing else."
42
  )
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def solve_task(task: Task) -> str:
45
+ return agent.run(task.question, task_id=task.task_id)