FD900 commited on
Commit
a1b5009
·
verified ·
1 Parent(s): 3df5153

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +69 -0
agent.py CHANGED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from smolagents import create_agent_executor, Agent, Task
4
+ from langgraph.prebuilt import ToolExecutor
5
+ from huggingface_hub import InferenceClient
6
+ from tools import *
7
+
8
+ # Load Hugging Face endpoint from environment
9
+ HF_API_URL = os.getenv("HF_ENDPOINT_URL")
10
+ HF_API_TOKEN = os.getenv("HF_TOKEN")
11
+
12
+ if not HF_API_URL or not HF_API_TOKEN:
13
+ raise ValueError("Missing Hugging Face endpoint URL or token.")
14
+
15
+ llm = InferenceClient(
16
+ model=HF_API_URL,
17
+ token=HF_API_TOKEN
18
+ )
19
+
20
+ def run_llm(prompt: str) -> str:
21
+ response = llm.text_generation(
22
+ prompt,
23
+ max_new_tokens=512,
24
+ do_sample=False,
25
+ temperature=0.0,
26
+ return_full_text=False,
27
+ )
28
+ return response.strip()
29
+
30
+ tool_list = [
31
+ GetAttachmentTool(),
32
+ GoogleSearchTool(),
33
+ GoogleSiteSearchTool(),
34
+ ContentRetrieverTool(),
35
+ SpeechRecognitionTool(),
36
+ YouTubeVideoTool(),
37
+ ClassifierTool(),
38
+ ImageToChessBoardFENTool()
39
+ ]
40
+
41
+ # Create tool executor for LangGraph
42
+ tool_executor = ToolExecutor(tool_list)
43
+
44
+ # Create agent instance
45
+ agent = Agent(
46
+ llm=run_llm,
47
+ tools=tool_list,
48
+ )
49
+
50
+ agent_executor = create_agent_executor(
51
+ agent=agent,
52
+ tool_executor=tool_executor,
53
+ stream=False,
54
+ )
55
+
56
+ def load_tasks(metadata_path="metadata.jsonl") -> list[Task]:
57
+ with open(metadata_path, "r") as f:
58
+ tasks = []
59
+ for line in f:
60
+ data = json.loads(line)
61
+ tasks.append(Task(
62
+ task_id=data["question_id"],
63
+ input=data["answer"]
64
+ ))
65
+ return tasks
66
+
67
+ def solve_task(task: Task) -> str:
68
+ result = agent_executor.invoke(task.input)
69
+ return result.get("output", "")