Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
agent.py
ADDED
|
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

from dotenv import load_dotenv

# Pull configuration (API keys, Supabase credentials) from a local .env file.
load_dotenv()

# --- Supabase Setup (only if credentials are provided) ---
supabase_url = os.getenv("SUPABASE_URL")
supabase_key = os.getenv("SUPABASE_SERVICE_KEY") or os.getenv("SUPABASE_KEY")

if supabase_url and supabase_key:
    # Supabase-related dependencies are imported lazily so the app still
    # starts when the vector store is not configured.
    from supabase.client import Client, create_client
    from langchain_community.vectorstores import SupabaseVectorStore
    from langchain.tools.retriever import create_retriever_tool
    from langchain_openai import OpenAIEmbeddings

    supabase: Client = create_client(supabase_url, supabase_key)
else:
    supabase = None

# --- Standard Imports ---
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from langchain_core.tools import tool

# OpenAI LLM
from langchain_openai import ChatOpenAI

# Optional document loaders
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
|
| 31 |
+
|
| 32 |
+
# --- Simple Math Tools ---
@tool
def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result"""
    # Plain integer product; Python ints never overflow.
    product = a * b
    return product
|
| 37 |
+
|
| 38 |
+
@tool
def add(a: int, b: int) -> int:
    """Add two integers and return the sum"""
    total = a + b
    return total
|
| 42 |
+
|
| 43 |
+
@tool
def subtract(a: int, b: int) -> int:
    """Subtract the second integer from the first and return the difference"""
    difference = a - b
    return difference
|
| 47 |
+
|
| 48 |
+
@tool
def divide(a: int, b: int) -> float:
    """Divide the first integer by the second and return the quotient"""
    # Reject a zero divisor explicitly so the LLM sees a readable error
    # instead of a raw ZeroDivisionError traceback.
    if b != 0:
        return a / b
    raise ValueError("Cannot divide by zero.")
|
| 54 |
+
|
| 55 |
+
@tool
def modulus(a: int, b: int) -> int:
    """Return the modulus of dividing the first integer by the second"""
    # Guard against b == 0 for consistency with `divide`; without this the
    # tool raised an unhandled ZeroDivisionError.
    if b == 0:
        raise ValueError("Cannot divide by zero.")
    return a % b
|
| 59 |
+
|
| 60 |
+
# --- Search Tools ---
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for the query and return up to 2 documents"""
    try:
        docs = WikipediaLoader(query=query, load_max_docs=2).load()
        # Use .get so a document missing the "source" metadata key degrades
        # gracefully instead of aborting the whole result set with KeyError.
        return "\n\n---\n\n".join(
            f'<Document source="{doc.metadata.get("source", "Unknown")}"/>\n{doc.page_content}'
            for doc in docs
        )
    except Exception as e:
        return f"Wikipedia search failed: {str(e)}"
|
| 71 |
+
|
| 72 |
+
@tool
def web_search(query: str) -> str:
    """Search the web using Tavily and return up to 3 results"""
    try:
        tavily_api_key = os.getenv("TAVILY_API_KEY")
        if not tavily_api_key:
            return "Web search unavailable: TAVILY_API_KEY not configured"

        results = TavilySearchResults(max_results=3, api_key=tavily_api_key).invoke(
            {"query": query}
        )
        # Each Tavily result is a dict with "url" and "content" keys.
        formatted = [
            f'<Document source="{r.get("url", "Unknown")}"/>\n{r.get("content", "")}'
            for r in results
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        return f"Web search failed: {str(e)}"
|
| 87 |
+
|
| 88 |
+
@tool
def arxiv_search(query: str) -> str:
    """Search Arxiv for the query and return up to 3 documents"""
    try:
        docs = ArxivLoader(query=query, load_max_docs=3).load()
        # ArxivLoader metadata typically exposes Title/Authors/Published and
        # may not include a "source" key, so the original indexing raised
        # KeyError and every call degraded to "Arxiv search failed".  Fall
        # back to the Title (then "Unknown") instead.  Page content is
        # truncated to 1000 chars to keep the LLM context small.
        return "\n\n---\n\n".join(
            f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", "Unknown"))}"/>\n{doc.page_content[:1000]}'
            for doc in docs
        )
    except Exception as e:
        return f"Arxiv search failed: {str(e)}"
|
| 98 |
+
|
| 99 |
+
# --- Assemble Tools List ---
tools = [multiply, add, subtract, divide, modulus, wiki_search, web_search, arxiv_search]

# If supabase is configured, add retriever tool.
# Identity check against None rather than truthiness: a Client instance's
# boolean behavior is not part of its contract.
if supabase is not None:
    try:
        embeddings = OpenAIEmbeddings()
        vector_store = SupabaseVectorStore(
            client=supabase,
            embedding=embeddings,
            table_name="documents",
            query_name="match_documents_langchain",
        )
        retriever_tool = create_retriever_tool(
            retriever=vector_store.as_retriever(),
            name="Question Search",
            description="Retrieve similar questions from the vector store",
        )
        tools.append(retriever_tool)
    except Exception as e:
        # Best-effort: a broken vector store should not prevent startup.
        print(f"Could not initialize Supabase retriever: {e}")
|
| 120 |
+
|
| 121 |
+
# --- Load System Prompt ---
def load_system_prompt():
    """Load the system prompt from ``system_prompt.txt``.

    Returns a SystemMessage built from the file contents, or a built-in
    default prompt when the file is missing or unreadable.
    """
    try:
        with open("system_prompt.txt", "r", encoding="utf-8") as f:
            return SystemMessage(content=f.read())
    except OSError:
        # OSError covers FileNotFoundError as well as permission/IO errors,
        # so a bad deployment never crashes at import time.
        default_prompt = """You are a helpful AI assistant with access to various tools including:
- Math operations (add, subtract, multiply, divide, modulus)
- Search capabilities (Wikipedia, Arxiv, web search via Tavily)
- Information retrieval

Use these tools when appropriate to answer questions accurately and helpfully. When performing calculations, always use the provided math tools. When users ask for information that might require current data or research, use the appropriate search tools.

Be concise but thorough in your responses. If you use a tool, explain what you found or calculated."""
        return SystemMessage(content=default_prompt)

sys_msg = load_system_prompt()
|
| 140 |
+
|
| 141 |
+
# --- Graph Builder (OpenAI) ---
def build_graph():
    """
    Build and return a compiled StateGraph using OpenAI ChatGPT with tools.

    If no usable OPENAI_API_KEY is available (or the connection test fails),
    falls back to a pattern-matching mock LLM so the graph is always
    constructible, e.g. for local/offline testing.
    """
    print("=== BUILDING OPENAI GRAPH ===")

    # Check for OpenAI API key
    openai_api_key = os.getenv("OPENAI_API_KEY")
    print(f"OpenAI API Key: {'Found' if openai_api_key else 'Not found'}")

    if openai_api_key:
        print(f"API Key starts with: {openai_api_key[:10]}...")

    try:
        if openai_api_key and len(openai_api_key.strip()) > 0:
            print("Attempting to initialize OpenAI ChatGPT...")

            # Initialize OpenAI LLM
            llm = ChatOpenAI(
                model="gpt-3.5-turbo",  # You can change to "gpt-4" if you have access
                temperature=0.1,
                api_key=openai_api_key.strip(),
                max_tokens=512
            )

            # Probe the API once so a bad key fails fast into the mock path.
            test_response = llm.invoke([HumanMessage(content="Hello")])
            print("✓ Successfully connected to OpenAI")
            print(f"Test response: {test_response.content[:50]}...")

        else:
            raise Exception("No valid OPENAI_API_KEY found")

    except Exception as e:
        print(f"Error initializing OpenAI LLM: {e}")
        print("Creating functional mock LLM...")

        class FunctionalMockLLM:
            """Minimal stand-in for ChatOpenAI: recognizes simple math
            expressions and search phrasing, emitting matching tool calls;
            everything else gets a canned echo response."""

            def bind_tools(self, tools):
                # Mirror the ChatOpenAI.bind_tools API; tools are kept only
                # for interface compatibility.
                self.tools = tools
                return self

            def invoke(self, messages):
                import re  # local import: only needed on the mock path

                last_msg = messages[-1] if messages else None
                if not last_msg:
                    return AIMessage(content="Please ask me a question!")

                content = getattr(last_msg, 'content', str(last_msg))
                content_lower = content.lower()

                # Handle math operations with tool calls
                math_patterns = [
                    (r'(\d+)\s*\+\s*(\d+)', 'add'),
                    (r'(\d+)\s*-\s*(\d+)', 'subtract'),
                    (r'(\d+)\s*\*\s*(\d+)', 'multiply'),
                    (r'(\d+)\s*/\s*(\d+)', 'divide'),
                    (r'(\d+)\s*%\s*(\d+)', 'modulus'),
                ]
                for pattern, operation in math_patterns:
                    match = re.search(pattern, content)
                    if match:
                        a, b = int(match.group(1)), int(match.group(2))
                        tool_call = {
                            "name": operation,
                            "args": {"a": a, "b": b},
                            "id": f"call_{operation}_{a}_{b}"
                        }
                        return AIMessage(
                            content=f"I'll {operation} {a} and {b} for you.",
                            tool_calls=[tool_call]
                        )

                # Handle search requests
                if any(word in content_lower for word in ['search', 'find', 'look up', 'what is', 'who is', 'tell me about']):
                    # Extract a search query by stripping common lead-in phrases.
                    search_query = content
                    for phrase in ['search for', 'find', 'look up', 'what is', 'who is', 'tell me about']:
                        search_query = search_query.lower().replace(phrase, '').strip()

                    if len(search_query) > 100:
                        search_query = search_query[:100]

                    # Route to the most specific search tool mentioned.
                    if 'wikipedia' in content_lower:
                        tool_name = "wiki_search"
                    elif 'arxiv' in content_lower or 'research' in content_lower or 'paper' in content_lower:
                        tool_name = "arxiv_search"
                    else:
                        tool_name = "web_search"

                    tool_call = {
                        "name": tool_name,
                        "args": {"query": search_query},
                        "id": f"call_{tool_name}_{hash(search_query) % 1000}"
                    }
                    return AIMessage(
                        content=f"I'll search for information about: {search_query}",
                        tool_calls=[tool_call]
                    )

                # Default response for other questions
                return AIMessage(content=f"I understand you're asking: {content[:200]}... I can help with math calculations and information searches. Please configure OPENAI_API_KEY for full functionality, or try asking me to calculate something or search for information.")

        llm = FunctionalMockLLM()
        print("✓ Using functional mock LLM")  # fixed mojibake in the original log string

    # Bind tools to LLM
    llm_with_tools = llm.bind_tools(tools)

    def retriever(state: MessagesState):
        """Prepend the system message; if Supabase is configured, also inject
        the closest matching stored Q/A as retrieved context."""
        messages = [sys_msg] + state["messages"]

        # len(tools) > 8 means the retriever tool was appended at import
        # time, which also guarantees `vector_store` exists.
        if supabase and len(tools) > 8:
            try:
                query = state["messages"][-1].content
                docs = vector_store.similarity_search(query, k=1)
                if docs:
                    doc = docs[0]
                    content = doc.page_content
                    # Stored documents may embed the answer after a
                    # "Final answer :" marker; extract it when present.
                    answer = content.split("Final answer :")[-1].strip() if "Final answer :" in content else content.strip()
                    return {"messages": messages + [AIMessage(content=f"Retrieved context: {answer}")]}
            except Exception as e:
                # Retrieval is best-effort; fall through to plain messages.
                print(f"Retrieval error: {e}")

        return {"messages": messages}

    def assistant(state: MessagesState):
        """Main assistant node: run the tool-bound LLM on the conversation."""
        try:
            response = llm_with_tools.invoke(state["messages"])
            return {"messages": [response]}
        except Exception as e:
            print(f"Assistant error: {e}")
            return {"messages": [AIMessage(content=f"I encountered an error: {str(e)}. Please make sure your OPENAI_API_KEY is configured correctly.")]}

    # Build the graph: retriever -> assistant <-> tools
    g = StateGraph(MessagesState)
    g.add_node("retriever", retriever)
    g.add_node("assistant", assistant)
    g.add_node("tools", ToolNode(tools))

    # Define edges
    g.add_edge(START, "retriever")
    g.add_edge("retriever", "assistant")
    g.add_conditional_edges("assistant", tools_condition)
    g.add_edge("tools", "assistant")

    print("✓ Graph compiled successfully")
    return g.compile()
|
app.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import gradio as gr
|
| 3 |
+
import requests
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from langchain_core.messages import HumanMessage
|
| 6 |
+
from agent import build_graph
|
| 7 |
+
|
| 8 |
+
# --- Constants ---
|
| 9 |
+
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 10 |
+
|
| 11 |
+
class BasicAgent:
    """A langgraph agent using OpenAI."""

    def __init__(self):
        print("=== INITIALIZING OPENAI BASIC AGENT ===")
        print(f"Current working directory: {os.getcwd()}")
        print(f"Files in directory: {os.listdir('.')}")

        # Check environment variables.  SECURITY: report only *presence* of
        # secrets -- the original echoed the first 10-15 characters of every
        # API key/token into the public Space logs.
        print("=== ENVIRONMENT VARIABLES ===")
        for key in sorted(os.environ.keys()):
            if any(term in key.upper() for term in ['OPENAI', 'API_KEY', 'TOKEN', 'TAVILY']):
                print(f"{key}: {'set' if os.environ[key] else 'empty'}")

        # Check specifically for OpenAI API key
        openai_key = os.getenv("OPENAI_API_KEY")
        if openai_key:
            print("✓ OpenAI API Key found")
        else:
            print("✗ OpenAI API Key not found!")
            print("Please add OPENAI_API_KEY to your Hugging Face Space secrets")

        try:
            self.graph = build_graph()
            print("✓ Graph built successfully")
        except Exception as e:
            print(f"✗ Error building graph: {e}")
            raise  # bare raise preserves the original traceback

    def __call__(self, question: str) -> str:
        """Run one question through the graph and return the final answer text."""
        print("=== AGENT CALL ===")
        print(f"Question: {question[:100]}...")

        try:
            messages = [HumanMessage(content=question)]
            print(f"Invoking graph with messages: {len(messages)}")

            result = self.graph.invoke({"messages": messages})
            print(f"Graph result keys: {result.keys() if isinstance(result, dict) else 'Not a dict'}")

            if 'messages' in result and result['messages']:
                # The last message in the state is the assistant's answer.
                answer = result['messages'][-1].content
                print(f"Answer (first 100 chars): {answer[:100]}...")
                return answer
            else:
                print("No messages in result")
                return "I apologize, but I couldn't generate a response."

        except Exception as e:
            print(f"Error in agent call: {e}")
            return f"Error: {str(e)}"
|
| 62 |
+
|
| 63 |
+
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch all evaluation questions, run the agent on each, and submit
    the answers to the scoring API.

    Args:
        profile: the logged-in Hugging Face OAuth profile, or None.

    Returns:
        A (status_message, results_dataframe) pair for the Gradio outputs.
    """
    print("=== STARTING RUN AND SUBMIT ===")
    space_id = os.getenv("SPACE_ID")
    print(f"Space ID: {space_id}")

    if not profile:
        return "Please Login to Hugging Face with the button.", None

    username = profile.username
    print(f"Username: {username}")

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    print("=== INITIALIZING AGENT ===")
    try:
        agent = BasicAgent()
        print("✓ Agent initialized successfully")
    except Exception as e:
        error_msg = f"Error initializing agent: {e}"
        print(error_msg)
        return error_msg, None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(f"Agent code URL: {agent_code}")

    print("=== FETCHING QUESTIONS ===")
    try:
        resp_q = requests.get(questions_url, timeout=15)
        resp_q.raise_for_status()
        questions = resp_q.json()
        print(f"✓ Fetched {len(questions)} questions")
    except Exception as e:
        error_msg = f"Error fetching questions: {e}"
        print(error_msg)
        return error_msg, None

    results_log = []
    answers = []

    print("=== PROCESSING QUESTIONS ===")
    for i, item in enumerate(questions):
        task_id = item.get("task_id")
        q = item.get("question")

        print(f"\n--- Question {i+1}/{len(questions)} ---")
        print(f"Task ID: {task_id}")

        # BUG FIX: validate BEFORE slicing -- the original printed q[:100]
        # first, raising TypeError whenever a question was None.
        if not task_id or q is None:
            print("Skipping - missing task_id or question")
            continue

        print(f"Question: {q[:100]}...")

        try:
            print("Calling agent...")
            ans = agent(q)
            print(f"Answer: {ans[:100]}...")

            answers.append({"task_id": task_id, "submitted_answer": ans})
            results_log.append({"Task ID": task_id, "Question": q, "Submitted Answer": ans})
            print("✓ Question processed successfully")

        except Exception as e:
            error_msg = f"ERROR: {e}"
            print(f"✗ Error processing question: {error_msg}")
            # Keep the failure visible in the results table.
            results_log.append({"Task ID": task_id, "Question": q, "Submitted Answer": error_msg})

    if not answers:
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    print(f"=== SUBMITTING {len(answers)} ANSWERS ===")
    payload = {"username": username.strip(), "agent_code": agent_code, "answers": answers}

    try:
        resp_s = requests.post(submit_url, json=payload, timeout=60)
        resp_s.raise_for_status()
        data = resp_s.json()

        status = (
            f"Submission Successful!\n"
            f"User: {data.get('username')}\n"
            f"Score: {data.get('score', 'N/A')}% "
            f"({data.get('correct_count', '?')}/{data.get('total_attempted', '?')})\n"
            f"{data.get('message', '')}"
        )
        print("✓ Submission successful")
        print(status)
        return status, pd.DataFrame(results_log)

    except Exception as e:
        error_msg = f"Submission Failed: {e}"
        print(error_msg)
        return error_msg, pd.DataFrame(results_log)
|
| 157 |
+
|
| 158 |
+
# Simple test function for debugging
def test_agent():
    """Test function to verify agent works"""
    print("=== TESTING AGENT ===")
    try:
        agent = BasicAgent()
        # Cover one addition, one multiplication, and one search request.
        sample_questions = (
            "What is 2 + 3?",
            "What is 10 * 5?",
            "Search for information about Python programming",
        )
        for question in sample_questions:
            print(f"\nTesting: {question}")
            answer = agent(question)
            print(f"Answer: {answer}")
    except Exception as e:
        print(f"Test failed: {e}")
|
| 177 |
+
|
| 178 |
+
# Gradio UI: login + test button, then the evaluation runner with its outputs.
with gr.Blocks() as demo:
    gr.Markdown("# OpenAI-Powered Agent Evaluation Runner")
    gr.Markdown("""
This agent uses OpenAI's GPT models instead of Hugging Face.

## Setup Instructions:
1. Get an OpenAI API key from https://platform.openai.com/api-keys
2. Add it as `OPENAI_API_KEY` in your Hugging Face Space secrets
3. (Optional) Add `TAVILY_API_KEY` for web search functionality
4. Log in with the button below
5. Click **Run Evaluation & Submit All Answers**

## Current Configuration:
- Model: GPT-3.5-turbo (change to GPT-4 in agent.py if you have access)
- Tools: Math operations, Wikipedia search, Arxiv search, Web search (if Tavily configured)
""")

    with gr.Row():
        gr.LoginButton()
        test_btn = gr.Button("Test Agent", variant="secondary")

    run_btn = gr.Button("Run Evaluation & Submit All Answers", variant="primary")
    status_out = gr.Textbox(label="Run Status / Submission Result", lines=5)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    # Button actions
    run_btn.click(fn=run_and_submit_all, outputs=[status_out, results_table])
    test_btn.click(fn=test_agent, outputs=[])
|
| 206 |
+
|
| 207 |
+
if __name__ == "__main__":
    print("=== STARTING OPENAI GRADIO APP ===")

    # Quick environment check.  SECURITY: report presence only -- the
    # original printed the first 15 characters of the API key to the logs.
    if os.getenv("OPENAI_API_KEY"):
        print("✓ OpenAI API Key configured")
    else:
        print("⚠️ OpenAI API Key not found - please add OPENAI_API_KEY to your Space secrets")

    demo.launch(debug=True, share=False)
|