i-dhilip commited on
Commit
40ddda8
·
verified ·
1 Parent(s): 02d16b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +375 -0
app.py CHANGED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """LangGraph Agent with Gradio Interface"""
2
+ import os
3
+ import gradio as gr
4
+ import requests
5
+ import pandas as pd
6
+ from dotenv import load_dotenv
7
+ from langgraph.graph import START, StateGraph, MessagesState
8
+ from langgraph.prebuilt import tools_condition, ToolNode
9
+ from langchain_google_genai import ChatGoogleGenerativeAI
10
+ from langchain_community.tools.tavily_search import TavilySearchResults
11
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
12
+ from langchain_community.vectorstores import Chroma
13
+ from langchain_core.messages import SystemMessage, HumanMessage
14
+ from langchain_core.tools import tool
15
+ from langchain.tools.retriever import create_retriever_tool
16
+ from langchain_community.embeddings import HuggingFaceEmbeddings
17
+
18
+ # Load environment variables
19
+ load_dotenv()
20
+
21
+ # Tool Definitions
22
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two numbers."""
    product = a * b
    return product
26
+
27
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def add(a: int, b: int) -> int:
    """Add two numbers."""
    total = a + b
    return total
31
+
32
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def subtract(a: int, b: int) -> int:
    """Subtract two numbers."""
    difference = a - b
    return difference
36
+
37
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def divide(a: int, b: int) -> float:
    """Divide two numbers."""
    # Fail with a readable message rather than a bare ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    # True division: the result is a float even when the quotient is whole,
    # which is why the return annotation is float rather than int.
    return a / b
43
+
44
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def modulus(a: int, b: int) -> int:
    """Get the modulus of two numbers."""
    # Mirror divide(): report a zero divisor with an explicit, readable error
    # instead of letting the raw ZeroDivisionError surface.
    if b == 0:
        raise ValueError("Cannot take modulus by zero.")
    return a % b
48
+
49
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def wiki_search(query: str) -> dict[str, str]:
    """Search Wikipedia for a query and return maximum 2 results."""
    # Returns a one-key dict ({"wiki_results": ...}); the annotation now says
    # so (it previously claimed `str`). Failures are reported in-band under the
    # same key so the agent loop can keep going.
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>'
            for doc in search_docs
        )
        return {"wiki_results": formatted_search_docs}
    except Exception as e:
        return {"wiki_results": f"Error: {str(e)}"}
60
+
61
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def web_search(query: str) -> dict[str, str]:
    """Search Tavily for a query and return maximum 3 results."""
    # Returns a one-key dict ({"web_results": ...}); failures are reported
    # in-band under the same key so the agent loop can keep going.
    try:
        # The tool input must be passed positionally; `invoke(query=query)`
        # leaves the required `input` argument unset and raises TypeError.
        search_hits = TavilySearchResults(max_results=3).invoke(query)

        # The Tavily tool hands back a plain string when the API call fails.
        if isinstance(search_hits, str):
            return {"web_results": search_hits}

        # Each hit is a plain dict with "url" and "content" keys, not a
        # LangChain Document, so there is no `.metadata` / `.page_content`.
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{hit.get("url", "")}"/>\n{hit.get("content", "")}\n</Document>'
            for hit in search_hits
        )
        return {"web_results": formatted_search_docs}
    except Exception as e:
        return {"web_results": f"Error: {str(e)}"}
72
+
73
# NOTE: the docstring doubles as the tool description handed to the LLM,
# so it is kept verbatim; extra notes live in comments instead.
@tool
def arvix_search(query: str) -> dict[str, str]:
    """Search Arxiv for a query and return maximum 3 results."""
    # Returns a one-key dict ({"arvix_results": ...}); failures are reported
    # in-band under the same key so the agent loop can keep going.
    try:
        search_docs = ArxivLoader(query=query, load_max_docs=3).load()

        def _source_of(doc) -> str:
            # Arxiv documents are not guaranteed to carry a "source" entry;
            # fall back to other identifying metadata instead of raising
            # KeyError and discarding every result.
            meta = doc.metadata
            return meta.get("source") or meta.get("entry_id") or meta.get("Title", "")

        # Only the first 1000 characters of each paper are forwarded, which
        # keeps the prompt small.
        formatted_search_docs = "\n\n---\n\n".join(
            f'<Document source="{_source_of(doc)}"/>\n{doc.page_content[:1000]}\n</Document>'
            for doc in search_docs
        )
        return {"arvix_results": formatted_search_docs}
    except Exception as e:
        return {"arvix_results": f"Error: {str(e)}"}
84
+
85
# System Prompt Setup
# Read the prompt from system_prompt.txt next to the app; when the file is
# absent, fall back to a placeholder prompt so the app can still start.
try:
    with open("system_prompt.txt", "r", encoding="utf-8") as prompt_file:
        system_prompt = prompt_file.read()
except FileNotFoundError:
    sys_msg = SystemMessage(content="Default system prompt")
else:
    sys_msg = SystemMessage(content=system_prompt)
92
+
93
# Vector Store Setup with error handling
# Any failure while building the embeddings or opening the Chroma collection
# is logged and leaves `vector_store` as None so later code can skip it.
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-mpnet-base-v2",
    )
    vector_store = Chroma(
        persist_directory="./chroma_db",
        embedding_function=embeddings,
        collection_name="documents",
    )
except Exception as init_error:
    vector_store = None
    print(f"Error initializing vector store: {init_error}")
104
+
105
# Tool Configuration with null check
# Base tool set: arithmetic helpers plus the three search tools.
tools = [
    multiply, add, subtract, divide, modulus,
    wiki_search, web_search, arvix_search,
]

# The retriever tool is only offered when the vector store came up.
if vector_store:
    tools.append(
        create_retriever_tool(
            vector_store.as_retriever(),
            # Function-calling APIs reject tool names containing spaces, so
            # use an identifier-style name instead of "Question Search".
            name="question_search",
            description="Retrieves similar questions from vector store",
        )
    )
else:
    print("Warning: Vector store not initialized. Question Search tool disabled.")
121
+
122
# Model Configuration
# Maps a model alias to the settings used when constructing the chat model.
MODEL_REGISTRY = {
    "gemini-2.0-flash": dict(
        model="gemini-2.0-flash",
        temperature=0,
        max_tokens=2048,
    ),
}
130
+
131
def get_llm(model_name: str = "gemini-2.0-flash"):
    """Create the chat model registered under ``model_name``.

    Names missing from ``MODEL_REGISTRY`` use the "gemini-2.0-flash" entry.
    Construction errors are printed and reported by returning ``None``
    rather than by raising.
    """
    fallback_settings = MODEL_REGISTRY["gemini-2.0-flash"]
    settings = MODEL_REGISTRY.get(model_name, fallback_settings)
    try:
        llm = ChatGoogleGenerativeAI(
            model=settings["model"],
            temperature=settings["temperature"],
            max_tokens=settings["max_tokens"],
        )
    except Exception as exc:
        print(f"Error initializing {model_name}: {exc}")
        return None
    return llm
143
+
144
# Updated Graph Builder Function
def build_graph():
    """Build and compile the LangGraph agent workflow with the Gemini model.

    Topology: START -> retriever -> assistant, then assistant <-> tools for as
    long as the model keeps requesting tool calls.

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.

    Raises:
        RuntimeError: if no LLM could be initialised.
    """
    primary_llm = get_llm("gemini-2.0-flash")

    # get_llm() returns None on failure; keep only usable models.
    llms = [llm for llm in [primary_llm] if llm is not None]

    if not llms:
        raise RuntimeError("Failed to initialize any LLM")

    current_llm_index = 0

    def assistant(state: MessagesState):
        """Ask the next LLM in the rotation, falling through to the others on error."""
        nonlocal current_llm_index

        # Always send the system prompt as the FIRST message. Nodes cannot
        # prepend to MessagesState (its reducer appends messages with new ids),
        # so the prompt is attached here rather than stored in the state.
        # Any system message already in the state is dropped to avoid
        # duplicates.
        conversation = [sys_msg] + [
            message for message in state["messages"]
            if not isinstance(message, SystemMessage)
        ]

        last_error = None
        for _ in range(len(llms)):
            index = current_llm_index
            # Rotate LLMs: the next call (or retry) starts with the next model.
            current_llm_index = (index + 1) % len(llms)
            try:
                llm_with_tools = llms[index].bind_tools(tools)
                response = llm_with_tools.invoke(conversation)
            except Exception as e:
                print(f"Model {llms[index].model} failed: {e}")
                last_error = e
                continue
            return {"messages": [response]}

        # Every model failed: surface the last error as a message so the graph
        # terminates normally instead of raising.
        error_msg = HumanMessage(content=f"All models failed: {str(last_error)}")
        return {"messages": [error_msg]}

    def retriever(state: MessagesState):
        """Append a similar-question reference (when available) to the conversation."""
        try:
            if vector_store:
                similar_questions = vector_store.similarity_search(
                    state["messages"][0].content,
                    k=1,
                )
                example_content = "Similar question reference: \n\n" + (
                    similar_questions[0].page_content if similar_questions
                    else "No similar questions found"
                )
            else:
                example_content = "Vector store not available"

            # Return only the NEW message. Re-emitting the existing messages
            # alongside the system prompt made the reducer place the system
            # prompt after the user's question rather than before it.
            return {"messages": [HumanMessage(content=example_content)]}
        except Exception as e:
            error_msg = HumanMessage(content=f"Retrieval error: {str(e)}")
            return {"messages": [error_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))

    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the last message requests tool
    # calls and ends the run otherwise.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
201
+
202
+
203
class BasicAgent:
    """LangGraph Agent Interface.

    Wraps the compiled graph so that the agent can be used as a plain
    ``question -> answer`` callable.
    """

    # Markers searched for in the model output, in priority order.
    _ANSWER_MARKERS = ("FINAL ANSWER: ", "Answer:")

    def __init__(self):
        self.graph = build_graph()

    @staticmethod
    def _content_to_text(content) -> str:
        """Flatten message content (a string or a list of parts) into one string."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # List-shaped content: keep bare strings and the "text" field of
            # dict parts; anything else (e.g. non-text parts) is ignored.
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return "".join(pieces)
        return str(content)

    @classmethod
    def _extract_answer(cls, text: str) -> str:
        """Return the text after the last answer marker, or ``text`` unchanged."""
        for marker in cls._ANSWER_MARKERS:
            if marker in text:
                answer = text.split(marker)[-1].strip()
                # Drop a trailing '"}' left over when the model wraps its
                # answer in a JSON-like envelope.
                if answer.endswith('"}'):
                    answer = answer[:-2].strip()
                return answer
        return text

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return the extracted answer.

        Never raises: any failure is returned as an
        ``"Agent processing error: ..."`` string.
        """
        try:
            messages = [HumanMessage(content=question)]
            result = self.graph.invoke({"messages": messages})
            last_message = self._content_to_text(result['messages'][-1].content)
            return self._extract_answer(last_message)
        except Exception as e:
            return f"Agent processing error: {str(e)}"
228
+
229
+
230
+
231
+
232
+
233
+
234
+ # Updated Agent Class
235
+ # class BasicAgent:
236
+ # """LangGraph Agent Interface"""
237
+ # def __init__(self):
238
+ # self.graph = build_graph()
239
+
240
+ # def __call__(self, question: str) -> str:
241
+ # try:
242
+ # messages = [HumanMessage(content=question)]
243
+ # result = self.graph.invoke({"messages": messages})
244
+ # last_message = result['messages'][-1].content
245
+
246
+ # # Improved content extraction
247
+ # if "FINAL ANSWER: " in last_message:
248
+ # return last_message.split("FINAL ANSWER: ")[-1].strip()
249
+ # elif "Answer:" in last_message:
250
+ # return last_message.split("Answer:")[-1].strip()
251
+ # return last_message
252
+ # except Exception as e:
253
+ # return f"Agent processing error: {str(e)}"
254
+
255
+
256
+
257
+
258
# Gradio Interface Functions
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Run the agent on every evaluation question and submit the answers.

    Args:
        profile: the logged-in Hugging Face profile, or ``None`` when the user
            has not logged in.

    Returns:
        A ``(status_text, results)`` tuple, where ``results`` is a DataFrame
        of the per-question log, or ``None`` when the user is not logged in.
        Errors are reported through ``status_text`` rather than raised.
    """
    if not profile:
        return "Please Login to Hugging Face with the button.", None

    space_id = os.getenv("SPACE_ID")
    api_url = "https://agents-course-unit4-scoring.hf.space"
    username = profile.username
    results_log = []

    try:
        agent = BasicAgent()
        # Link to this Space's code; the URL must not contain whitespace.
        agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

        # Fetch questions
        response = requests.get(f"{api_url}/questions", timeout=15)
        response.raise_for_status()
        questions_data = response.json()

        # Process questions
        answers_payload = []
        for item in questions_data:
            task_id = item.get("task_id")
            question_text = item.get("question")
            # Skip malformed entries instead of aborting the whole run.
            if not task_id or not question_text:
                continue

            try:
                answer = agent(question_text)
                answers_payload.append({
                    "task_id": task_id,
                    "submitted_answer": answer,
                })
                results_log.append({
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": answer,
                })
            except Exception as e:
                # A failed question is logged for display but not submitted.
                results_log.append({
                    "Task ID": task_id,
                    "Question": question_text,
                    "Submitted Answer": f"AGENT ERROR: {e}",
                })

        # Nothing to submit (no questions, or every one failed): say so
        # instead of posting an empty answer list.
        if not answers_payload:
            return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

        # Submit answers
        submission_data = {
            "username": username.strip(),
            "agent_code": agent_code,
            "answers": answers_payload,
        }

        response = requests.post(f"{api_url}/submit", json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()

        final_status = (
            f"Submission Successful!\nOverall Score: {result_data.get('score', 'N/A')}%\n"
            f"Correct: {result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')}\n"
            f"Message: {result_data.get('message', 'No message')}"
        )
        return final_status, pd.DataFrame(results_log)

    except Exception as e:
        # Return whatever was logged so far so partial results stay visible.
        return f"Error: {str(e)}", pd.DataFrame(results_log)
324
+
325
# Gradio UI Setup
# Layout, top to bottom: title, instructions, login button, run button,
# status box, results table. Clicking the run button calls
# run_and_submit_all and writes its two return values to the status box and
# the table.
with gr.Blocks() as demo:
    gr.Markdown(value="# Basic Agent Evaluation Runner")
    gr.Markdown(
        value="""
        **Instructions:**
        1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
        ---
        **Disclaimers:**
        Once clicking on the "submit button, it can take quite some time (this is the time for the agent to go through all the questions).
        This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance, for the delay process of the submit button, a solution could be to cache the answers and submit in a separate action or even to answer the questions in async.
        """
    )

    gr.LoginButton()

    run_button = gr.Button(value="Run Evaluation & Submit All Answers")

    status_output = gr.Textbox(
        label="Run Status / Submission Result",
        interactive=False,
        lines=5,
    )
    results_table = gr.DataFrame(
        label="Questions and Agent Answers",
        wrap=True,
    )

    run_button.click(
        fn=run_and_submit_all,
        outputs=[status_output, results_table],
    )
352
+
353
if __name__ == "__main__":
    # Startup banner: report the Space-related environment variables for
    # information, then launch the UI.
    print("\n" + "-"*30 + " App Starting " + "-"*30)
    # Check for SPACE_HOST and SPACE_ID at startup for information
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")  # Get SPACE_ID at startup

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f" Runtime URL should be: https://{space_host_startup}.hf.space")
    else:
        print("ℹ️ SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:  # Print repo URLs if SPACE_ID is found
        print(f"✅ SPACE_ID found: {space_id_startup}")
        # The printed URLs must not contain whitespace.
        print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️ SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)