kingkaikai commited on
Commit
a075fae
·
verified ·
1 Parent(s): c613356

update by using smolagent

Browse files
Files changed (2) hide show
  1. app.py +9 -8
  2. tools.py +87 -1
app.py CHANGED
@@ -4,8 +4,7 @@ import gradio as gr
4
  import requests
5
  import inspect
6
  import pandas as pd
7
- from agent import SmoalAgent
8
- from tools import search_tool, rag_chain, extract_final_answer
9
 
10
  # --- Constants ---
11
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
@@ -13,7 +12,7 @@ SUBMISSION_FILE = "submission.jsonl"
13
 
14
 
15
  def run_and_submit_all(profile: gr.OAuthProfile | None):
16
- """Fetches all questions, runs the SmoalAgent on them, submits all answers,
17
  and displays the results."""
18
  # --- Determine HF Space Runtime URL and Repo URL ---
19
  space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
@@ -31,7 +30,9 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
31
 
32
  # 1. Instantiate Agent
33
  try:
34
- agent = SmoalAgent()
 
 
35
  except Exception as e:
36
  print(f"Error instantiating agent: {e}")
37
  return f"Error initializing agent: {e}", None
@@ -71,7 +72,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
71
  print(f"Skipping item with missing task_id or question: {item}")
72
  continue
73
  try:
74
- # 使用导入的搜索工具和RAG
75
  search_result = search_tool.run(question_text)
76
  if rag_chain:
77
  response = rag_chain.run(f"{question_text}\nSearch result: {search_result}")
@@ -83,7 +84,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
83
  "reasoning_trace": response
84
  })
85
  else:
86
- submitted_answer = agent(question_text)
87
  answers_payload.append({
88
  "task_id": task_id,
89
  "model_answer": submitted_answer
@@ -163,7 +164,7 @@ def run_and_submit_all(profile: gr.OAuthProfile | None):
163
 
164
  # --- Build Gradio Interface using Blocks ---
165
  with gr.Blocks() as demo:
166
- gr.Markdown("# Smoal Agent Evaluation Runner")
167
  gr.Markdown(
168
  """
169
  **Instructions:**
@@ -213,5 +214,5 @@ if __name__ == "__main__":
213
 
214
  print("-"*(60 + len(" App Starting ")) + "\n")
215
 
216
- print("Launching Gradio Interface for Smoal Agent Evaluation...")
217
  demo.launch(debug=True, share=False)
 
4
  import requests
5
  import inspect
6
  import pandas as pd
7
+ from tools import search_tool, rag_chain, extract_final_answer, initialize_code_agent
 
8
 
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
 
12
 
13
 
14
  def run_and_submit_all(profile: gr.OAuthProfile | None):
15
+ """Fetches all questions, runs the CodeAgent on them, submits all answers,
16
  and displays the results."""
17
  # --- Determine HF Space Runtime URL and Repo URL ---
18
  space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
 
30
 
31
  # 1. Instantiate Agent
32
  try:
33
+ agent = initialize_code_agent()
34
+ if not agent:
35
+ raise Exception("Failed to initialize CodeAgent")
36
  except Exception as e:
37
  print(f"Error instantiating agent: {e}")
38
  return f"Error initializing agent: {e}", None
 
72
  print(f"Skipping item with missing task_id or question: {item}")
73
  continue
74
  try:
75
+ # Use imported search tool and RAG chain
76
  search_result = search_tool.run(question_text)
77
  if rag_chain:
78
  response = rag_chain.run(f"{question_text}\nSearch result: {search_result}")
 
84
  "reasoning_trace": response
85
  })
86
  else:
87
+ submitted_answer = agent.run(question_text)
88
  answers_payload.append({
89
  "task_id": task_id,
90
  "model_answer": submitted_answer
 
164
 
165
  # --- Build Gradio Interface using Blocks ---
166
  with gr.Blocks() as demo:
167
+ gr.Markdown("# Code Agent Evaluation Runner")
168
  gr.Markdown(
169
  """
170
  **Instructions:**
 
214
 
215
  print("-"*(60 + len(" App Starting ")) + "\n")
216
 
217
+ print("Launching Gradio Interface for Code Agent Evaluation...")
218
  demo.launch(debug=True, share=False)
tools.py CHANGED
@@ -7,6 +7,7 @@ from langchain.vectorstores import FAISS
7
  from langchain.prompts import PromptTemplate
8
  from datasets import load_dataset
9
  from agent import SmoalAgent
 
10
 
11
  # System prompt for formatting answers
12
  SYSTEM_PROMPT = """
@@ -58,4 +59,89 @@ def extract_final_answer(response):
58
 
59
  # Initialize RAG chain
60
  global rag_chain
61
- rag_chain = load_gaia_and_setup_rag()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from langchain.prompts import PromptTemplate
8
  from datasets import load_dataset
9
  from agent import SmoalAgent
10
+ from smolagents import CodeAgent, DuckDuckGoSearchTool, InferenceClientModel
11
 
12
  # System prompt for formatting answers
13
  SYSTEM_PROMPT = """
 
59
 
60
  # Initialize RAG chain
61
  global rag_chain
62
+ rag_chain = load_gaia_and_setup_rag()
63
+
64
+ # Initialize search tool
65
+ search_tool = DuckDuckGoSearchTool()
66
+
67
+ # Load GAIA dataset and setup RAG
68
+ rag_chain = None
69
+
70
+ def load_gaia_and_setup_rag():
71
+ try:
72
+ from datasets import load_dataset
73
+ # Load GAIA dataset (test split)
74
+ dataset = load_dataset("gaia-benchmark/gaia", split="test")
75
+
76
+ # Extract contexts from dataset
77
+ contexts = [item["context"] for item in dataset if "context" in item and item["context"]]
78
+
79
+ # Create embeddings and vector store
80
+ embeddings = OpenAIEmbeddings()
81
+ vector_store = FAISS.from_texts(contexts, embeddings)
82
+
83
+ # Create retriever
84
+ retriever = vector_store.as_retriever(search_kwargs={"k": 3})
85
+
86
+ # Define prompt template
87
+ SYSTEM_PROMPT = """
88
+ You are a precise QA system. Answer ONLY with the exact answer, no explanations.
89
+ Answers must be in one of these formats:
90
+ - A single number
91
+ - A single string
92
+ - A comma-separated list of numbers or strings
93
+ Do not include any additional text, explanations, or formatting.
94
+ """
95
+
96
+ prompt_template = PromptTemplate(
97
+ template=SYSTEM_PROMPT + "\nContext: {context}\nQuestion: {question}\nAnswer:",
98
+ input_variables=["context", "question"]
99
+ )
100
+
101
+ # Create RAG chain
102
+ global rag_chain
103
+ rag_chain = RetrievalQA.from_chain_type(
104
+ llm=OpenAI(temperature=0),
105
+ chain_type="stuff",
106
+ retriever=retriever,
107
+ chain_type_kwargs={"prompt": prompt_template}
108
+ )
109
+
110
+ print(f"Successfully loaded GAIA dataset and created RAG chain with {len(contexts)} contexts")
111
+ return True
112
+ except Exception as e:
113
+ print(f"Error setting up RAG: {e}")
114
+ return False
115
+
116
+ # Initialize RAG when the module is loaded
117
+ load_gaia_and_setup_rag()
118
+
119
+ # Initialize CodeAgent
120
+ def initialize_code_agent():
121
+ try:
122
+ # Initialize model with environment variables
123
+ model = InferenceClientModel(
124
+ api_key=os.getenv("OPENAI_API_KEY"),
125
+ model_name="gpt-3.5-turbo"
126
+ )
127
+
128
+ # Create agent with search tool
129
+ agent = CodeAgent(
130
+ tools=[search_tool],
131
+ model=model
132
+ )
133
+
134
+ print("CodeAgent initialized successfully")
135
+ return agent
136
+ except Exception as e:
137
+ print(f"Error initializing CodeAgent: {e}")
138
+ return None
139
+
140
+ # Final answer extraction
141
+ def extract_final_answer(text):
142
+ # Use regex to find the final answer pattern
143
+ match = re.search(r'FINAL ANSWER: (.*)', text, re.IGNORECASE)
144
+ if match:
145
+ return match.group(1).strip()
146
+ # If no pattern found, return the text as is (with cleanup)
147
+ return text.strip()