|
|
import os |
|
|
import gradio as gr |
|
|
import requests |
|
|
import inspect |
|
|
import pandas as pd |
|
|
import asyncio |
|
|
from langchain_google_genai import ChatGoogleGenerativeAI |
|
|
from typing import IO, Dict, TypedDict, Annotated, Sequence, Any, Callable |
|
|
from io import BytesIO |
|
|
from langchain_core.messages import HumanMessage, SystemMessage, BaseMessage, AIMessage |
|
|
from langgraph.graph import StateGraph, END |
|
|
import base64 |
|
|
from google.ai.generativelanguage_v1beta.types import Tool as GenAITool |
|
|
import google.generativeai as genai |
|
|
import operator |
|
|
from langchain_core.tools import tool |
|
|
from utilities import get_file |
|
|
import time |
|
|
from tenacity import retry, stop_after_attempt, wait_exponential |
|
|
import json |
|
|
import logging |
|
|
|
|
|
|
|
|
# Root-logger setup: INFO level so the agent's step-by-step traces are visible
# in the Space logs.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# API key for Google Generative AI (used by open_youtube_video).
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
# Base URL of the HF Agents-course Unit 4 scoring service (questions/files/submit).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
# NOTE(review): unusual mixed-case env var name "Gemini_API_key" — confirm it
# matches the secret configured on the Space; GEMINI_API_KEY is not referenced
# anywhere in this file.
GEMINI_API_KEY = os.getenv("Gemini_API_key")
# NOTE(review): SERPER_API_KEY is read but never used in this file — possibly
# left over from an earlier search implementation.
SERPER_API_KEY = os.getenv("SERPER_API_KEY")
|
|
|
|
|
|
|
|
class AgentState(TypedDict):
    """Shared state passed between LangGraph nodes."""
    # Conversation history; the operator.add annotation tells LangGraph to
    # concatenate lists returned by nodes instead of replacing the history.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Name of the next node to run ("agent", "tools", or END).
    next: str
|
|
|
|
|
|
|
|
def execute_tool(tool_name: str, tool_args: Dict[str, Any], tools: Dict[str, Callable]) -> Any:
    """Look up *tool_name* in the registry and invoke it with *tool_args*.

    Raises ValueError when the name is not registered.  LangChain tool
    objects are invoked through their ``run`` entry point; plain callables
    are called directly.
    """
    if tool_name not in tools:
        raise ValueError(f"Tool {tool_name} not found")

    selected = tools[tool_name]
    use_run = hasattr(selected, 'run')
    return selected.run(**tool_args) if use_run else selected(**tool_args)
|
|
|
|
|
|
|
|
@tool
def analyse_excel(task_id: str) -> Dict[str, float]:
    '''Analyzes the Excel file associated with the given task_id.'''
    # Fetch the spreadsheet from the scoring service, then total every
    # numeric column on the first sheet.
    workbook = get_file(task_id)
    first_sheet = pd.read_excel(workbook, sheet_name=0)
    numeric_totals = first_sheet.select_dtypes(include='number').sum()
    return numeric_totals.to_dict()
|
|
|
|
|
@tool
def add_numbers(a: float, b: float) -> float:
    '''Adds two numbers together.'''
    # Trivial arithmetic tool exposed to the agent for sanity-check tasks.
    total = a + b
    return total
|
|
|
|
|
@tool
def transcribe_audio(task_id: str) -> HumanMessage:
    '''Transcribes an audio file.'''
    # Pull the audio payload for this task from the scoring service.
    audio_stream = get_file(task_id)
    if audio_stream is None:
        raise ValueError("No audio file found for the given task_id.")

    # Rewind before reading in case the stream was already consumed, then
    # base64-encode so it can be inlined into a multimodal message.
    audio_stream.seek(0)
    audio_b64 = base64.b64encode(audio_stream.read()).decode("utf-8")

    # Multimodal message: a text instruction plus the inline audio blob.
    parts = [
        {"type": "text", "text": "Transcribe the audio."},
        {
            "type": "media",
            "data": audio_b64,
            "mime_type": "audio/mpeg",
        },
    ]
    return HumanMessage(content=parts)
|
|
|
|
|
@tool
def python_code(task_id: str) -> str:
    '''Returns the Python code associated with the given task_id.'''
    # Fetch the task's attached source file from the scoring service.
    # FIX: bounded timeout — the original call had none, so a stalled file
    # server could hang the whole evaluation run indefinitely.
    code_request = requests.get(url=f'{DEFAULT_API_URL}/files/{task_id}', timeout=15)
    # Surface HTTP errors (404 for missing file, 5xx) to the caller.
    code_request.raise_for_status()
    return code_request.text
|
|
|
|
|
@tool
def open_image(task_id: str) -> str:
    '''Opens an image file associated with the given task_id.'''
    # Fetch the image payload for this task from the scoring service.
    image_file = get_file(task_id)
    if image_file is None:
        raise ValueError("No image file found for the given task_id.")
    # FIX: rewind before reading, mirroring transcribe_audio — if get_file
    # returns a stream whose cursor is already at EOF, read() yields b""
    # and the tool silently returns an empty base64 string.
    image_file.seek(0)
    # Base64 so the image can be embedded in a multimodal prompt.
    return base64.b64encode(image_file.read()).decode("utf-8")
|
|
|
|
|
@tool
def open_youtube_video(url: str, query: str) -> str:
    '''Answers a question about a video from the given URL.'''
    # FIX: the original body called genai.Client(...) and referenced a
    # never-imported `types` module.  The imported SDK here is
    # `google.generativeai`, which exposes neither name (Client belongs to
    # the newer `google-genai` package), so every invocation raised before
    # reaching the API.  Use the configure()/GenerativeModel API that this
    # SDK actually provides.
    genai.configure(api_key=GOOGLE_API_KEY)
    model = genai.GenerativeModel('models/gemini-2.0-flash')
    response = model.generate_content(
        [
            # Pass the video by URI so Gemini fetches it server-side.
            genai.protos.Part(file_data=genai.protos.FileData(file_uri=url)),
            f'''{query} YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated
        list of numbers and/or strings.''',
        ]
    )
    return response.text
|
|
|
|
|
def _extract_instant_answer(data):
    """Collect the useful text fields from a DuckDuckGo Instant Answer payload."""
    found = []
    if data.get("Abstract"):
        found.append(data["Abstract"])
    if data.get("Definition"):
        found.append(f"Definition: {data['Definition']}")
    if data.get("RelatedTopics"):
        # At most 5 top-level topics; category entries nest their results
        # under "Topics", from which we take at most 2 each.
        for topic in data["RelatedTopics"][:5]:
            if "Text" in topic:
                found.append(topic["Text"])
            elif "Topics" in topic:
                for subtopic in topic["Topics"][:2]:
                    if "Text" in subtopic:
                        found.append(subtopic["Text"])
    if data.get("Answer"):
        found.append(f"Answer: {data['Answer']}")
    return found


def _wikipedia_intro(query):
    """Fallback lookup: plain-text intro of the Wikipedia page titled *query*.

    Returns a (possibly empty) list of extracts.  Network/parse errors are
    logged and swallowed so the caller can still report "No results found.".
    """
    extracts = []
    try:
        wiki_response = requests.get(
            "https://en.wikipedia.org/w/api.php",
            params={
                "action": "query",
                "format": "json",
                "prop": "extracts",
                "exintro": True,
                "explaintext": True,
                "titles": query
            },
            timeout=10
        )
        wiki_data = wiki_response.json()
        pages = wiki_data.get("query", {}).get("pages", {})
        for page in pages.values():
            extract = page.get("extract")
            if extract:
                extracts.append(extract)
    except Exception as e:
        logger.error(f"Wikipedia fallback error for query {query}: {str(e)}")
    return extracts


def google_search(query: str) -> str:
    '''Performs a web search for the given query using DuckDuckGo, and falls back to Wikipedia if no results are found.'''
    # Decomposed from a single ~100-line body: instant-answer extraction and
    # the Wikipedia fallback now live in private helpers; behavior unchanged.
    try:
        # Crude rate limiting so repeated agent searches don't hammer the API.
        time.sleep(1)

        response = requests.get(
            "https://api.duckduckgo.com/",
            params={
                "q": query,
                "format": "json",
                "no_html": 1,
                "no_redirect": 1,
                "skip_disambig": 1
            },
            timeout=10
        )
        response.raise_for_status()
        result = _extract_instant_answer(response.json())

        # Second attempt: an exact-phrase query sometimes hits when the
        # plain one misses.  Only Abstract/Answer are considered here,
        # matching the original behavior.
        if not result:
            quoted_response = requests.get(
                "https://api.duckduckgo.com/",
                params={
                    "q": f'"{query}"',
                    "format": "json",
                    "no_html": 1,
                    "no_redirect": 1
                },
                timeout=10
            )
            quoted_data = quoted_response.json()
            if quoted_data.get("Abstract"):
                result.append(quoted_data["Abstract"])
            if quoted_data.get("Answer"):
                result.append(f"Answer: {quoted_data['Answer']}")

        # Last resort: Wikipedia page intro.
        if not result:
            result.extend(_wikipedia_intro(query))

        return "\n".join(result) if result else "No results found."

    except requests.exceptions.Timeout:
        logger.warning(f"Search timeout for query: {query}")
        return "Search timed out. Please try again."
    except requests.exceptions.RequestException as e:
        logger.error(f"Search error for query {query}: {str(e)}")
        return f"Error performing search: {str(e)}"
    except Exception as e:
        logger.error(f"Unexpected error during search for query {query}: {str(e)}")
        return "An unexpected error occurred during the search."
|
|
|
|
|
def log_message(message: BaseMessage, prefix: str = ""):
    """Log a human-readable summary of a LangChain message at INFO level."""
    if isinstance(message, HumanMessage):
        logger.info(f"{prefix}Human Message:")
        if not isinstance(message.content, list):
            logger.info(f"{prefix}  {message.content}")
            return
        # Multimodal content: log each part, summarizing media blobs
        # instead of dumping their raw base64 payload.
        for item in message.content:
            if not isinstance(item, dict):
                logger.info(f"{prefix}  {item}")
            elif item.get("type") == "media":
                logger.info(f"{prefix}  Media content (type: {item.get('mime_type')})")
            else:
                logger.info(f"{prefix}  {item.get('type')}: {item.get('text')}")
        return

    if isinstance(message, AIMessage):
        logger.info(f"{prefix}AI Message:")
        logger.info(f"{prefix}  Content: {message.content}")
        if hasattr(message, 'tool_calls') and message.tool_calls:
            logger.info(f"{prefix}  Tool Calls: {json.dumps(message.tool_calls, indent=2)}")
        return

    if isinstance(message, SystemMessage):
        logger.info(f"{prefix}System Message:")
        logger.info(f"{prefix}  {message.content}")
|
|
|
|
|
class BasicAgent:
    """LangGraph-driven question-answering agent with a primary (OpenAI) and
    fallback (Gemini) LLM, a small tool registry, and answer post-processing
    tuned to the scoring service's expected short-answer format."""

    def __init__(self):
        """Build the LLM clients, tool registry, system prompt and the
        two-node (agent/tools) LangGraph workflow."""
        # Primary LLM only when an OpenAI key is configured; import is kept
        # local so the package is optional.
        if os.getenv("OPENAI_API_KEY"):
            from langchain_openai import ChatOpenAI
            self.primary_llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0,
                max_tokens=4096
            )
        else:
            self.primary_llm = None

        # Gemini fallback is always constructed.
        # NOTE(review): langchain-google-genai's canonical parameter is
        # max_output_tokens — confirm max_tokens is accepted as an alias by
        # the installed version.
        self.fallback_llm = ChatGoogleGenerativeAI(
            model="gemini-2.5-flash-preview-05-20",
            max_tokens=8192,
            temperature=0,
            convert_system_message_to_human=True
        )

        # Registry consumed by execute_tool(); values are a mix of plain
        # callables and langchain @tool objects.
        self.tools = {
            "get_file": get_file,
            "analyse_excel": analyse_excel,
            "add_numbers": add_numbers,
            "transcribe_audio": transcribe_audio,
            "python_code": python_code,
            "open_image": open_image,
            "open_youtube_video": open_youtube_video,
            "google_search": google_search
        }

        # OpenAI function-calling schema: only google_search is advertised.
        self.openai_tools = [{
            "type": "function",
            "function": {
                "name": "google_search",
                "description": "Search for information on the web. Use this tool to find specific information about the question.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query to find relevant information"
                        }
                    },
                    "required": ["query"]
                }
            }
        }]

        # Same single-tool schema in Gemini's function_declarations shape.
        self.gemini_tools = [{
            "function_declarations": [{
                "name": "google_search",
                "description": "Search for information on the web. Use this tool to find specific information about the question.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "The search query to find relevant information"
                        }
                    },
                    "required": ["query"]
                }
            }]
        }]

        # System prompt enforcing the course's strict final-answer format.
        self.sys_msg = SystemMessage('''You are a general AI assistant. I will ask you a question. Follow these steps:

1. First, use the google_search tool to find relevant information about the question.
2. Analyze the search results to find the specific information needed.
3. If needed, use additional tools to gather more information.
4. Only after gathering all necessary information, provide YOUR FINAL ANSWER.

YOUR FINAL ANSWER must be:
- For numbers: Just the digit (e.g., "7" not "seven" or "7 albums")
- For strings: As few words as possible
- For lists: A comma-separated list of numbers and/or strings

Rules for formatting:
- For numbers: Don't use commas or units ($, %, etc.) unless specified
- For strings: Don't use articles or abbreviations
- For lists: Apply the above rules based on whether each element is a number or string

IMPORTANT:
- You MUST use the google_search tool before providing your final answer
- Format your tool calls as: {"name": "google_search", "arguments": {"query": "your search query"}}
- Your final answer should ONLY be the requested information, no explanations
- If you need to search again, use the tool again
- Do not provide detailed analysis in your final answer
- If you encounter rate limits, inform the user that you need to search for information
- Never make up information - if you can't find it, say so''')

        # Two-node graph: agent <-> tools.
        self.workflow = StateGraph(AgentState)

        self.workflow.add_node("agent", self.call_model)
        self.workflow.add_node("tools", self.call_tools)

        # NOTE(review): both edges are unconditional and no conditional
        # edges are registered, so the "next" field the nodes return is
        # never consulted for routing — termination appears to rely on
        # nodes returning END state / the recursion limit. Confirm intended.
        self.workflow.add_edge("agent", "tools")
        self.workflow.add_edge("tools", "agent")

        self.workflow.set_entry_point("agent")

        self.app = self.workflow.compile()

        # NOTE(review): spelling is "LANGRAPH", not "LANGGRAPH", and
        # LangGraph normally takes recursion_limit via the invoke config —
        # this env var may have no effect. TODO confirm.
        os.environ["LANGRAPH_RECURSION_LIMIT"] = "50"

        logger.info("BasicAgent initialized with fallback LLM support.")

    def _call_model_with_retry(self, state: AgentState) -> AgentState:
        """Internal method to handle retries for model calls."""
        # Up to 3 attempts with linear backoff (5s, 10s); each attempt tries
        # the primary LLM first and falls back to Gemini on any error.
        max_retries = 3
        retry_count = 0
        last_error = None

        while retry_count < max_retries:
            try:
                messages = state["messages"]
                logger.info("\n=== Model Input ===")
                log_message(self.sys_msg, "  ")
                for msg in messages:
                    log_message(msg, "  ")

                try:
                    # Treat a missing primary LLM as an error so control
                    # drops straight into the Gemini fallback below.
                    if self.primary_llm is None:
                        raise ValueError("Primary LLM not initialized")

                    logger.info("Attempting to use primary LLM (OpenAI)")

                    # Nudge the model toward emitting a JSON tool call.
                    messages_with_tool_prompt = [self.sys_msg] + messages + [
                        HumanMessage(content="Use the google_search tool to find the information. Format your response as a JSON object with 'name' and 'arguments' fields.")
                    ]

                    response = self.primary_llm.invoke(
                        messages_with_tool_prompt,
                        tools=self.openai_tools
                    )

                    if not response or not hasattr(response, 'content'):
                        raise ValueError("Invalid response format from OpenAI")

                    # Native tool calls route straight to the tools node.
                    if hasattr(response, 'tool_calls') and response.tool_calls:
                        logger.info("Successfully used primary LLM with tools")
                        return {"messages": [response], "next": "tools"}
                    else:
                        # NOTE(review): re-invokes without tools even though a
                        # response was already obtained — a second paid call.
                        response = self.primary_llm.invoke(messages_with_tool_prompt)
                        if not response or not hasattr(response, 'content'):
                            raise ValueError("Invalid response format from OpenAI")
                        logger.info("Successfully used primary LLM without tools")

                except Exception as e:
                    error_str = str(e)
                    logger.error(f"Primary LLM error: {error_str}")

                    # Fallback path: Gemini, with the system prompt folded
                    # into a HumanMessage for compatibility.
                    try:
                        logger.info("Attempting to use fallback LLM (Gemini)")

                        if isinstance(self.sys_msg, SystemMessage):
                            system_content = f"System Instructions: {self.sys_msg.content}"
                            messages_with_system = [HumanMessage(content=system_content)] + messages
                        else:
                            messages_with_system = [self.sys_msg] + messages

                        messages_with_tool_prompt = messages_with_system + [
                            HumanMessage(content="Use the google_search tool to find the information. Format your response as a JSON object with 'name' and 'arguments' fields.")
                        ]

                        response = self.fallback_llm.invoke(
                            messages_with_tool_prompt,
                            tools=self.gemini_tools
                        )

                        if not response or not hasattr(response, 'content'):
                            raise ValueError("Invalid response format from Gemini")

                        if hasattr(response, 'tool_calls') and response.tool_calls:
                            logger.info("Successfully used fallback LLM with tools")
                            return {"messages": [response], "next": "tools"}
                        else:
                            response = self.fallback_llm.invoke(messages_with_tool_prompt)
                            if not response or not hasattr(response, 'content'):
                                raise ValueError("Invalid response format from Gemini")
                            logger.info("Successfully used fallback LLM without tools")

                    except Exception as fallback_error:
                        logger.error(f"Fallback LLM error: {str(fallback_error)}")
                        # Both LLMs failed: end the graph with a user-facing
                        # message, distinguishing rate limiting (HTTP 429).
                        if "429" in str(fallback_error):
                            return {
                                "messages": [AIMessage(content="All LLM services are currently rate limited. Please try again later.")],
                                "next": END
                            }
                        else:
                            return {
                                "messages": [AIMessage(content="All LLM services are currently unavailable. Please try again later.")],
                                "next": END
                            }

                logger.info("\n=== Model Output ===")
                log_message(response, "  ")

                if not response or not response.content:
                    logger.error("Empty response from model")
                    raise ValueError("Empty response from model")

                content = response.content.strip()

                # Chatty acknowledgments ("let me...", "sure...") mean the
                # model didn't actually call a tool — loop back to the agent.
                if any(phrase in content.lower() for phrase in ["let me", "i'll", "i will", "sure", "okay", "alright"]):
                    logger.info("Model provided acknowledgment instead of tool call, prompting for search")
                    return {
                        "messages": [AIMessage(content="Please use the google_search tool to find the information.")],
                        "next": "agent"
                    }

                # Strip a markdown "**Final Answer**: " prefix if present.
                if content.startswith("**Final Answer**: "):
                    content = content.replace("**Final Answer**: ", "").strip()

                # Normalize whole-number floats like "7.0" to "7".
                if content.replace(".", "").isdigit():
                    if float(content).is_integer():
                        content = str(int(float(content)))

                # Bare numbers or bracketed lists are accepted as final
                # answers; anything else loops back for another agent turn.
                if content.isdigit() or (content.startswith('[') and content.endswith(']')):
                    return {"messages": [AIMessage(content=content)], "next": END}
                else:
                    return {"messages": [response], "next": "agent"}

            except Exception as e:
                last_error = e
                retry_count += 1
                logger.error(f"Error in processing, retry {retry_count}/{max_retries}: {str(e)}")
                if retry_count < max_retries:
                    # Linear backoff: 5s, then 10s.
                    wait_time = 5 * retry_count
                    time.sleep(wait_time)
                else:
                    logger.error(f"All retries failed. Last error: {str(last_error)}")
                    return {
                        "messages": [AIMessage(content="Unable to generate answer after multiple attempts. Please try again later.")],
                        "next": END
                    }

        # Defensive: reached only if the while loop exits without returning.
        return {
            "messages": [AIMessage(content="Unable to generate answer after multiple attempts. Please try again later.")],
            "next": END
        }

    def call_model(self, state: AgentState) -> AgentState:
        """Call the model to generate a response with retry logic and fallback support."""
        return self._call_model_with_retry(state)

    def call_tools(self, state: AgentState) -> AgentState:
        """Call the tools based on the model's response."""
        try:
            messages = state["messages"]
            last_message = messages[-1]

            logger.info("\n=== Tool Execution ===")
            if isinstance(last_message, AIMessage):
                # The prompt asks models to emit tool calls as a bare JSON
                # object in the message content, so parse that first.
                content = last_message.content.strip()
                try:
                    if content.startswith('{') and content.endswith('}'):
                        tool_call = json.loads(content)
                        if isinstance(tool_call, dict) and 'name' in tool_call and 'arguments' in tool_call:
                            tool_name = tool_call['name']
                            tool_args = tool_call['arguments']
                            logger.info(f"Executing tool: {tool_name}")
                            logger.info(f"Tool arguments: {json.dumps(tool_args, indent=2)}")

                            result = execute_tool(tool_name, tool_args, self.tools)
                            logger.info(f"Tool result: {result}")

                            # NOTE(review): messages is mutated in place AND
                            # returned whole below; under the operator.add
                            # annotation this re-appends the entire history.
                            # Confirm intended.
                            messages.append(AIMessage(content=f"Tool result: {result}"))

                            # Empty search results end the run immediately.
                            if isinstance(result, str) and "no results found" in result.lower():
                                return {"messages": [AIMessage(content="Not found")], "next": END}

                            # After a search, explicitly ask for the final answer.
                            if tool_name == "google_search":
                                messages.append(HumanMessage(content="Based on the search results, please provide your final answer."))
                            return {"messages": messages, "next": "agent"}
                except json.JSONDecodeError:
                    # Not a JSON tool call — fall through to heuristics.
                    pass

                # No parseable tool call: nudge chatty replies toward using
                # the search tool, accept answer-shaped content as final.
                content = last_message.content.strip().lower()
                if any(phrase in content for phrase in ["let me", "i'll", "i will", "sure", "okay", "alright"]):
                    logger.info("No tool calls found, prompting for search")
                    messages.append(AIMessage(content="Please use the google_search tool to find the information."))
                else:
                    logger.info("No tool calls found in AI message")
                    if content.isdigit() or (content.startswith('[') and content.endswith(']')):
                        return {"messages": [last_message], "next": END}
                    else:
                        return {"messages": messages, "next": "agent"}

            return {"messages": messages, "next": "agent"}
        except Exception as e:
            logger.error(f"Error in call_tools: {str(e)}")
            # NOTE(review): if the exception fired before messages was bound,
            # this reference would raise NameError — unlikely but possible.
            return {"messages": messages, "next": "agent"}

    async def __call__(self, question: str, task_id: str) -> str:
        """Process a question and return the answer with error handling."""
        logger.info(f"\n=== Processing Question ===")
        logger.info(f"Task ID: {task_id}")
        logger.info(f"Question: {question}")

        try:
            # Seed the graph with the task id + question as one human turn.
            initial_state = {
                "messages": [HumanMessage(content=f'Task id: {task_id}\n {question}')],
                "next": "agent"
            }

            # NOTE(review): synchronous invoke inside an async method blocks
            # the event loop; callers currently wrap this in asyncio.run so
            # it works, but it isn't truly async.
            result = self.app.invoke(initial_state)
            final_message = result["messages"][-1]

            if isinstance(final_message, AIMessage) and final_message.content:
                logger.info(f"\n=== Final Answer ===")
                logger.info(f"Answer: {final_message.content}")
                return final_message.content
            else:
                logger.error("Empty or invalid response")
                raise ValueError("Empty or invalid response")

        except Exception as e:
            logger.error(f"Fatal error in agent: {str(e)}")
            return f"Error: {str(e)}"
|
|
|
|
|
def run_and_submit_all(profile):
    """
    Fetches all questions, runs the BasicAgent on them, submits all answers,
    and displays the results.
    """
    # SPACE_ID identifies this HF Space; used to build the agent_code URL
    # sent with the submission.
    space_id = os.getenv("SPACE_ID")

    # Require an HF login before doing anything.
    if profile:
        username = str(profile)
        print(f"User logged in: {username}")
    else:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    # Instantiate the agent; a failure here aborts the whole run.
    try:
        agent = BasicAgent()
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    # Public link to this Space's code, included in the submission payload.
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
    print(agent_code)

    # --- Fetch the question set from the scoring service ---
    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
        if not questions_data:
            print("Fetched questions list is empty.")
            return "Fetched questions list is empty or invalid format.", None
        print(f"Fetched {len(questions_data)} questions.")
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except Exception as e:
        print(f"An unexpected error occurred fetching questions: {e}")
        return f"An unexpected error occurred fetching questions: {e}", None

    # --- Run the agent on every question, collecting answers + a log ---
    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            # agent.__call__ is async, so each question is driven through
            # its own asyncio.run (sequential, not concurrent).
            submitted_answer = asyncio.run(agent(question_text, task_id))
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
        except Exception as e:
            # Record per-question failures but keep going.
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    # --- Submit all answers in one POST ---
    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        results_df = pd.DataFrame(results_log)
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        # Extract the server's error detail when the body is JSON.
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"
        print(status_message)
        results_df = pd.DataFrame(results_log)
        return status_message, results_df
|
|
|
|
|
|
|
|
|
|
|
# --- Gradio UI: login button + one-click "run everything" evaluation ---
with gr.Blocks() as demo:
    gr.Markdown("# Basic Agent Evaluation Runner")
    gr.Markdown(
        """
        **Instructions:**
        1. Please clone this space, then modify the code to define your agent's logic, the tools, the necessary packages, etc ...
        2. Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        3. Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
        ---
        **Disclaimers:**
        Once clicking on the "submit button, it can take quite some time ( this is the time for the agent to go through all the questions).
        This space provides a basic setup and is intentionally sub-optimal to encourage you to develop your own, more robust solution. For instance for the delay process of the submit button, a solution could be to cache the answers and submit in a seperate action or even to answer the questions in async.
        """
    )

    with gr.Row():
        login_button = gr.LoginButton()
        run_button = gr.Button("Run Evaluation & Submit All Answers")

    # Outputs: a status textbox and a table of per-question answers.
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    def run_evaluation(profile):
        # Guard against clicking "Run" before logging in; Gradio passes the
        # OAuth profile (or None) because login_button is wired as an input.
        if not profile:
            return "Please login first.", None
        return run_and_submit_all(profile)

    run_button.click(
        fn=run_evaluation,
        inputs=[login_button],
        outputs=[status_output, results_table]
    )
|
|
|
|
|
if __name__ == "__main__":
    # Startup diagnostics: print the Space's runtime and repo URLs when the
    # corresponding env vars are set (they are absent when running locally).
    print("\n" + "-"*30 + " App Starting " + "-"*30)

    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        print(f"   Runtime URL should be: https://{space_host_startup}.hf.space")
    else:
        print("ℹ️  SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️  SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    # Launch the UI; debug=True surfaces tracebacks in the console.
    print("Launching Gradio Interface for Basic Agent Evaluation...")
    demo.launch(debug=True, share=False)