Spaces:
Configuration error
Configuration error
| import os | |
| import gradio as gr | |
| import requests | |
| import aiohttp | |
| import asyncio | |
| import json | |
| import nest_asyncio | |
| from langgraph.graph import StateGraph, END | |
| from langgraph.checkpoint.memory import MemorySaver | |
| from langchain_huggingface import HuggingFacePipeline | |
| from transformers import pipeline | |
| from langchain_core.messages import SystemMessage, HumanMessage | |
| from tools import search_tool, multi_hop_search_tool, file_parser_tool, image_parser_tool, calculator_tool, document_retriever_tool | |
| from tools.search import initialize_search_tools | |
| from state import JARVISState | |
| import pandas as pd | |
| from dotenv import load_dotenv | |
| import logging | |
| from langfuse.callback import CallbackHandler | |
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# nest_asyncio lets run_until_complete be re-entered; BasicAgent.__call__ drives its
# coroutine with loop.run_until_complete from synchronous code.
nest_asyncio.apply()

# Pull variables from a local .env file (if present) into the process environment.
load_dotenv()

# Fail fast at import time when a mandatory setting is missing; the first missing
# name (in list order) is the one reported.
# NOTE(review): Langfuse initialization further down tolerates failure, yet its keys
# are mandatory here — confirm whether they should really be required.
required_env_vars = ["SPACE_ID", "LANGFUSE_PUBLIC_KEY", "LANGFUSE_SECRET_KEY"]
_first_missing_var = next((name for name in required_env_vars if not os.getenv(name)), None)
if _first_missing_var is not None:
    raise ValueError(f"Environment variable {_first_missing_var} is not set")
_space_id_preview = os.getenv("SPACE_ID")[:10]
_langfuse_host = os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com")
logger.info(f"Environment variables loaded: SPACE_ID={_space_id_preview}..., LANGFUSE_HOST={_langfuse_host}")
# Initialize Hugging Face model.
# "mistralai/Mixtral-7B-Instruct-v0.1" is not a published checkpoint (Mixtral ships as
# 8x7B / 8x22B), so loading always failed and llm fell back to None. The 7B instruct
# model is Mistral-7B-Instruct.
# NOTE(review): Mistral repos may be gated — an HF_TOKEN with access may be required.
HF_MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.1"
try:
    hf_pipeline = pipeline(
        "text-generation",
        model=HF_MODEL_ID,
        device_map="auto",
        max_new_tokens=512,
        do_sample=True,
        temperature=0.7,
        # Return only the completion. By default text-generation pipelines echo the
        # prompt, which would break json.loads() of the tool list and pollute answers.
        return_full_text=False,
    )
    llm = HuggingFacePipeline(pipeline=hf_pipeline)
    logger.info(f"HuggingFace model initialized: {HF_MODEL_ID}")
except Exception as e:
    # Best effort: every node checks `if llm:` and falls back when it is None.
    logger.error(f"Failed to initialize HuggingFace model: {e}")
    llm = None
# Hand the LLM (None if loading failed) to the search tools. A failure here is
# logged but not fatal, so the app still starts.
try:
    initialize_search_tools(llm)
except Exception as exc:
    logger.error(f"Failed to initialize search tools: {exc}")
else:
    logger.info("Search tools initialized")
# Initialize Langfuse tracing. Optional: if the handler cannot be built, `langfuse`
# is None and LLM calls are made with an empty callbacks list.
_langfuse_kwargs = {
    "public_key": os.getenv("LANGFUSE_PUBLIC_KEY"),
    "secret_key": os.getenv("LANGFUSE_SECRET_KEY"),
    "host": os.getenv("LANGFUSE_HOST", "https://cloud.langfuse.com"),
}
try:
    langfuse = CallbackHandler(**_langfuse_kwargs)
except Exception as exc:
    logger.warning(f"Failed to initialize Langfuse: {exc}")
    langfuse = None
else:
    logger.info("Langfuse initialized successfully")
# Initialize MemorySaver: in-memory LangGraph checkpointer; each task is run under
# its own thread_id (see BasicAgent.process_question).
memory = MemorySaver()
use_checkpointing = True

# --- Constants ---
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space/api"
# Task attachments are served by the same scoring API under /files/{task_id}. The
# previous host (api.gaia-benchmark.com) does not appear to be a real GAIA endpoint,
# so the probe always failed and no attachment was ever downloaded.
# NOTE(review): the course template uses the base URL without the "/api" suffix —
# confirm DEFAULT_API_URL against the scoring Space's API docs.
GAIA_FILE_URL = f"{DEFAULT_API_URL}/files/"
# --- Helper Functions ---
def log_state(task_id: str, state: JARVISState):
    """Append a snapshot of ``state`` to ``state_log.json`` (best effort).

    Each call writes exactly one JSON object on its own line (JSON Lines), so the
    log can be read back with ``[json.loads(line) for line in f]``. Previously the
    objects were pretty-printed back to back, which made the file neither valid
    JSON nor valid JSONL.

    Args:
        task_id: Identifier of the task the snapshot belongs to.
        state: Current graph state; keys missing from it are recorded as ``None``.

    Errors are logged and swallowed so that logging can never break the graph.
    """
    fields = (
        "question",
        "tools_needed",
        "web_results",
        "file_results",
        "image_results",
        "calculation_results",
        "document_results",
        "answer",
    )
    try:
        log_entry = {"task_id": task_id, **{key: state.get(key) for key in fields}}
        # default=str is deliberate: tool results may be arbitrary objects and this
        # is a debug log, so stringify them rather than losing the whole entry.
        line = json.dumps(log_entry, ensure_ascii=False, default=str)
        with open("state_log.json", "a", encoding="utf-8") as f:
            f.write(line + "\n")
    except Exception as e:
        logger.error(f"Error logging state for task {task_id}: {e}")
async def test_gaia_api(task_id: str) -> bool:
    """Probe the GAIA file endpoint for ``task_id`` with a HEAD request.

    Returns True when the server answers 200, 403 or 404 (the endpoint responded,
    whether or not a file exists). Returns False for any other status, and for a
    network or timeout error, which is logged as a warning.
    """
    reachable_statuses = (200, 403, 404)
    probe_url = f"{GAIA_FILE_URL}{task_id}"
    try:
        async with aiohttp.ClientSession() as session:
            async with session.head(probe_url, timeout=5) as resp:
                status = resp.status
    except Exception as e:
        logger.warning(f"GAIA API test failed: {e}")
        return False
    return status in reachable_statuses
# --- Node Functions ---
# Tool names the dispatcher understands; anything else suggested by the LLM is ignored.
_KNOWN_TOOLS = (
    "web_search",
    "multi_hop_search",
    "file_parser",
    "image_parser",
    "calculator",
    "document_retriever",
)


def _extract_tool_list(text: str) -> list[str] | None:
    """Return the known tool names listed in the first JSON array inside ``text``.

    LLM replies often wrap the array in prose or Markdown fences, so the span from
    the first "[" to the last "]" is parsed instead of the whole reply. Non-strings,
    unknown names and duplicates are dropped (first occurrence wins).

    Returns:
        The filtered list, which may be empty if the model asked for no tools. Returns
        None when no JSON array parses, when the value is not a list, or when a
        non-empty list names no known tool.
    """
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end < start:
        return None
    try:
        parsed = json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, list):
        return None
    tools: list[str] = []
    for item in parsed:
        if isinstance(item, str) and item in _KNOWN_TOOLS and item not in tools:
            tools.append(item)
    if parsed and not tools:
        return None
    return tools


async def parse_question(state: JARVISState) -> JARVISState:
    """Graph node: decide which tools the question needs.

    Asks the LLM for a JSON list of tool names and stores it in
    ``state["tools_needed"]``.

    - Falls back to ``["web_search"]`` when no LLM is loaded or the reply contains
      no usable list.
    - Falls back to ``[]`` if the LLM call itself fails.

    The state is logged and returned in every case.
    """
    try:
        question = state["question"]
        prompt = (
            f"Analyze this GAIA question: {question}\n"
            "Determine which tools are needed (web_search, multi_hop_search, file_parser, "
            "image_parser, calculator, document_retriever).\n"
            "Return a JSON list of tool names."
        )
        if llm:
            response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
            # HuggingFacePipeline is a plain LLM whose ainvoke returns a str, while chat
            # models return a message with .content. Reading .content unconditionally
            # raised AttributeError, so every question ended up with no tools.
            text = response if isinstance(response, str) else getattr(response, "content", str(response))
            tools_needed = _extract_tool_list(text)
            if tools_needed is None:
                logger.warning(
                    f"Invalid JSON in LLM response for task {state['task_id']}: {text[:200]!r}"
                )
                tools_needed = ["web_search"]
        else:
            logger.warning("No LLM available, using default tools")
            tools_needed = ["web_search"]
        state["tools_needed"] = tools_needed
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error parsing question for task {state['task_id']}: {e}")
        state["tools_needed"] = []
        log_state(state["task_id"], state)
        return state
async def tool_dispatcher(state: JARVISState) -> JARVISState:
    """Graph node: run every tool named in ``state["tools_needed"]`` and merge results.

    - Tools run sequentially in the requested order. Order can matter:
      ``calculator`` reads ``file_results``.
    - ``web_search`` and ``multi_hop_search`` are both served by one
      ``web_search_agent`` call, which already runs every requested search, so it
      is invoked at most once. Previously, listing both tools ran both searches
      twice and duplicated ``web_results``.
    - File-based tools are skipped when the GAIA file endpoint is unreachable.
    - A failing, skipped or unknown tool is logged.

    Returns:
        A shallow copy of ``state`` with tool results filled in, or ``state`` itself
        if dispatching fails outright.
    """
    try:
        updated_state = state.copy()
        # Copy the list as well: extending the shared list would mutate the input state.
        updated_state["web_results"] = list(updated_state.get("web_results") or [])
        task_id = updated_state["task_id"]
        can_download_files = await test_gaia_api(task_id)
        # tool name -> (agent coroutine, state key it fills, needs the downloaded file)
        single_result_tools = {
            "file_parser": (file_parser_agent, "file_results", True),
            "image_parser": (image_parser_agent, "image_results", True),
            "calculator": (calculator_agent, "calculation_results", False),
            "document_retriever": (document_retriever_agent, "document_results", True),
        }
        web_search_done = False
        for tool in state["tools_needed"]:
            try:
                if tool in ("web_search", "multi_hop_search"):
                    if web_search_done:
                        continue
                    web_search_done = True
                    result = await web_search_agent(updated_state)
                    updated_state["web_results"].extend(result["web_results"])
                elif tool in single_result_tools:
                    agent, key, needs_file = single_result_tools[tool]
                    if needs_file and not can_download_files:
                        logger.warning(f"Skipping tool {tool} for task {task_id}: GAIA file API unavailable")
                        continue
                    result = await agent(updated_state)
                    updated_state[key] = result[key]
                else:
                    logger.warning(f"Unknown tool {tool!r} requested for task {task_id}")
            except Exception as e:
                logger.warning(f"Error in tool {tool} for task {task_id}: {e}")
        log_state(task_id, updated_state)
        return updated_state
    except Exception as e:
        logger.error(f"Error in tool dispatcher for task {state['task_id']}: {e}")
        log_state(state["task_id"], state)
        return state
async def _invoke_tool(tool, payload: dict):
    """Call ``tool.ainvoke(payload)`` if available, else ``tool.invoke`` (awaited if needed).

    The original code did ``await tool.invoke(...)``. On a LangChain tool ``invoke``
    is synchronous and returns a plain value, so awaiting it raised TypeError and
    every search silently produced no results. This works for LangChain-style tools
    and for custom objects whose ``invoke`` is a coroutine function.
    NOTE(review): confirm the concrete tool types defined in ``tools/``.
    """
    import inspect

    ainvoke = getattr(tool, "ainvoke", None)
    if callable(ainvoke):
        return await ainvoke(payload)
    result = tool.invoke(payload)
    if inspect.isawaitable(result):
        result = await result
    return result


async def web_search_agent(state: JARVISState) -> JARVISState:
    """Run the requested web searches for ``state["question"]``.

    Returns a partial state ``{"web_results": [...]}`` with one entry per search
    that ran: ``web_search`` and/or a 3-step ``multi_hop_search``, depending on
    ``state["tools_needed"]``. Any error is logged and yields an empty list.
    """
    try:
        query = state["question"]
        results = []
        if "web_search" in state["tools_needed"]:
            results.append(await _invoke_tool(search_tool, {"query": query}))
        if "multi_hop_search" in state["tools_needed"]:
            results.append(await _invoke_tool(multi_hop_search_tool, {"query": query, "steps": 3}))
        return {"web_results": results}
    except Exception as e:
        logger.error(f"Error in web search for task {state['task_id']}: {e}")
        return {"web_results": []}
async def file_parser_agent(state: JARVISState) -> JARVISState:
    """Parse the task's attachment when ``file_parser`` was requested.

    The file type is guessed from the question: "csv" if it mentions "data",
    otherwise "txt". Returns ``{"file_results": ...}`` holding one of:
    - the parser output;
    - "" when the tool was not requested;
    - "File parsing failed" on error.
    """
    try:
        if "file_parser" not in state["tools_needed"]:
            return {"file_results": ""}
        mentions_data = "data" in state["question"].lower()
        parsed = await file_parser_tool.aparse(
            state["task_id"], file_type="csv" if mentions_data else "txt"
        )
        return {"file_results": parsed}
    except Exception as e:
        logger.error(f"Error in file parser for task {state['task_id']}: {e}")
        return {"file_results": "File parsing failed"}
async def image_parser_agent(state: JARVISState) -> JARVISState:
    """Analyse ``temp_<task_id>.jpg`` when ``image_parser`` was requested.

    Questions mentioning "fruits" use the tool's "match" task with query "fruits";
    any other image is just described. Returns ``{"image_results": ...}`` holding one of:
    - the parser output;
    - "" when the tool was not requested;
    - "Image file not found" if the file is missing;
    - "Image parsing failed" on error.
    """
    try:
        if "image_parser" not in state["tools_needed"]:
            return {"image_results": ""}
        wants_fruit_match = "fruits" in state["question"].lower()
        task = "match" if wants_fruit_match else "describe"
        match_query = "fruits" if wants_fruit_match else ""
        image_path = f"temp_{state['task_id']}.jpg"
        if not os.path.exists(image_path):
            logger.warning(f"Image file not found for task {state['task_id']}")
            return {"image_results": "Image file not found"}
        parsed = await image_parser_tool.aparse(image_path, task=task, match_query=match_query)
        return {"image_results": parsed}
    except Exception as e:
        logger.error(f"Error in image parser for task {state['task_id']}: {e}")
        return {"image_results": "Image parsing failed"}
async def calculator_agent(state: JARVISState) -> JARVISState:
    """Evaluate a math expression extracted from the question (and any file results).

    When ``calculator`` was requested, asks the LLM to extract an expression from
    the question plus ``state["file_results"]`` and passes it to ``calculator_tool``.
    Without an LLM the expression defaults to "0".

    Returns ``{"calculation_results": ...}`` holding one of:
    - the tool output;
    - "" when the tool was not requested;
    - "Calculation failed" on error.
    """
    try:
        if "calculator" not in state["tools_needed"]:
            return {"calculation_results": ""}
        prompt = (
            f"Extract a mathematical expression from: {state['question']}\n{state['file_results']}\n"
            "Return only the expression, with no explanation."
        )
        if llm:
            response = await llm.ainvoke(prompt, config={"callbacks": [langfuse] if langfuse else []})
            # HuggingFacePipeline returns a str (chat models return a message with
            # .content); reading .content on a str always raised AttributeError here.
            text = response if isinstance(response, str) else getattr(response, "content", str(response))
            # Drop whitespace and Markdown backticks the model may wrap the expression in.
            expression = text.strip().strip("`").strip()
        else:
            expression = "0"
        result = await calculator_tool.aparse(expression)
        return {"calculation_results": result}
    except Exception as e:
        logger.error(f"Error in calculator for task {state['task_id']}: {e}")
        return {"calculation_results": "Calculation failed"}
async def document_retriever_agent(state: JARVISState) -> JARVISState:
    """Query the task's document when ``document_retriever`` was requested.

    File type is chosen from question keywords: "pdf" for "report"/"document",
    otherwise "txt" for "menu", otherwise "csv". Returns ``{"document_results": ...}``
    holding one of:
    - the retriever output;
    - "" when the tool was not requested;
    - "Document retrieval failed" on error.

    NOTE(review): BasicAgent.process_question saves "menu" questions' attachments
    as .pdf (and the default as .txt), while this picks "txt" (default "csv") —
    confirm how document_retriever_tool locates the file.
    """
    try:
        if "document_retriever" not in state["tools_needed"]:
            return {"document_results": ""}
        lowered = state["question"].lower()
        if "report" in lowered or "document" in lowered:
            file_type = "pdf"
        elif "menu" in lowered:
            file_type = "txt"
        else:
            file_type = "csv"
        retrieved = await document_retriever_tool.aparse(
            state["task_id"], state["question"], file_type=file_type
        )
        return {"document_results": retrieved}
    except Exception as e:
        logger.error(f"Error in document retriever for task {state['task_id']}: {e}")
        return {"document_results": "Document retrieval failed"}
async def reasoning_agent(state: JARVISState) -> JARVISState:
    """Graph node: synthesize the final exact-match answer from all tool results.

    Stores the stripped LLM reply in ``state["answer"]``.
    - Stores "Unknown" when no LLM is loaded.
    - Stores "Error in reasoning" if anything fails.

    Logs the state and returns it.
    """
    try:
        prompt = (
            f"Question: {state['question']}\n"
            f"Web Results: {state['web_results']}\n"
            f"File Results: {state['file_results']}\n"
            f"Image Results: {state['image_results']}\n"
            f"Calculation Results: {state['calculation_results']}\n"
            f"Document Results: {state['document_results']}\n"
            "Synthesize an exact-match answer for the GAIA benchmark.\n"
            "Output only the answer (e.g., '90', 'White;5876')."
        )
        if llm:
            response = await llm.ainvoke(
                [
                    SystemMessage(content="You are JARVIS, a precise assistant for the GAIA benchmark. Provide exact answers only."),
                    HumanMessage(content=prompt),
                ],
                config={"callbacks": [langfuse] if langfuse else []},
            )
            # HuggingFacePipeline returns a str, chat models a message with .content.
            # The old `response.content.strip()` raised on str, so every answer
            # became "Error in reasoning".
            text = response if isinstance(response, str) else getattr(response, "content", str(response))
            answer = text.strip()
        else:
            answer = "Unknown"
        state["answer"] = answer
        log_state(state["task_id"], state)
        return state
    except Exception as e:
        logger.error(f"Error in reasoning for task {state['task_id']}: {e}")
        state["answer"] = "Error in reasoning"
        log_state(state["task_id"], state)
        return state
def router(state: JARVISState) -> str:
    """Choose the node after parsing: tools if any were requested, else straight to reasoning."""
    return "tool_dispatcher" if state["tools_needed"] else "reasoning"
# --- Define StateGraph ---
# Flow: parse -> tool_dispatcher -> reasoning -> END, or parse -> reasoning -> END
# when router() finds no tools to run.
workflow = StateGraph(JARVISState)
for _node_name, _node_fn in (
    ("parse", parse_question),
    ("tool_dispatcher", tool_dispatcher),
    ("reasoning", reasoning_agent),
):
    workflow.add_node(_node_name, _node_fn)
workflow.set_entry_point("parse")
# router() returns the target node's own name, so the path map is an identity map.
_route_map = {target: target for target in ("tool_dispatcher", "reasoning")}
workflow.add_conditional_edges("parse", router, _route_map)
workflow.add_edge("tool_dispatcher", "reasoning")
workflow.add_edge("reasoning", END)

# Compile graph, attaching the in-memory checkpointer only when checkpointing is on.
graph = workflow.compile(checkpointer=memory if use_checkpointing else None)
# --- Basic Agent Definition ---
class BasicAgent:
    """Synchronous wrapper that answers one GAIA task by running the JARVIS graph."""

    # Total timeout (seconds) for downloading a task attachment. Without it aiohttp's
    # default total timeout applies, and one stalled download can hold up the whole run.
    DOWNLOAD_TIMEOUT_S = 60

    def __init__(self):
        logger.info("BasicAgent initialized.")

    async def process_question(self, task_id: str, question: str) -> str:
        """Download the task's attachment (if reachable), run the graph, return the answer.

        The attachment is saved as ``temp_<task_id>.<ext>``. ``<ext>`` is guessed from
        question keywords: "jpg" for "image", overridden by "pdf" for
        "menu"/"report"/"document", else "csv" for "data", else "txt". The file is
        always deleted after the graph has run.

        Graph errors are returned as ``"Error: ..."`` strings rather than raised.
        """
        lowered = question.lower()
        file_type = "jpg" if "image" in lowered else "txt"
        if "menu" in lowered or "report" in lowered or "document" in lowered:
            file_type = "pdf"
        elif "data" in lowered:
            file_type = "csv"
        file_path = f"temp_{task_id}.{file_type}"

        if await test_gaia_api(task_id):
            try:
                timeout = aiohttp.ClientTimeout(total=self.DOWNLOAD_TIMEOUT_S)
                async with aiohttp.ClientSession(timeout=timeout) as session:
                    async with session.get(f"{GAIA_FILE_URL}{task_id}") as resp:
                        if resp.status == 200:
                            # Read the body first so a failed transfer creates no file.
                            content = await resp.read()
                            with open(file_path, "wb") as f:
                                f.write(content)
                        else:
                            logger.warning(f"Failed to download file for task {task_id}: HTTP {resp.status}")
            except Exception as e:
                logger.error(f"Error downloading file for task {task_id}: {e}")

        state = JARVISState(
            task_id=task_id,
            question=question,
            tools_needed=[],
            web_results=[],
            file_results="",
            image_results="",
            calculation_results="",
            document_results="",
            messages=[],
            answer="",
        )
        try:
            config = {"configurable": {"thread_id": task_id}} if use_checkpointing else {}
            result = await graph.ainvoke(state, config=config)
            return result.get("answer") or "No answer generated"
        except Exception as e:
            logger.error(f"Error processing task {task_id}: {e}")
            return f"Error: {str(e)}"
        finally:
            # Remove the downloaded attachment whatever the outcome.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    logger.error(f"Error removing file {file_path}: {e}")

    async def async_call(self, question: str, task_id: str) -> str:
        """Async entry point with ``(question, task_id)`` argument order."""
        return await self.process_question(task_id, question)

    def __call__(self, question: str, task_id: str | None = None) -> str:
        """Answer ``question`` synchronously.

        Runs ``async_call`` on this thread's event loop. If ``get_event_loop``
        raises RuntimeError (e.g. in a worker thread with no loop), a new loop is
        created and installed for the thread. A missing ``task_id`` is replaced by
        a placeholder.
        """
        logger.info(f"Agent received question (first 50 chars): {question[:50]}...")
        if task_id is None:
            logger.warning("task_id not provided, using placeholder")
            task_id = "placeholder_task_id"
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        return loop.run_until_complete(self.async_call(question, task_id))
# --- Main Function ---
def run_and_submit_all(profile: gr.OAuthProfile | None):
    """Fetch every question, answer it with ``BasicAgent``, and submit all answers.

    Args:
        profile: OAuth profile of the logged-in Hugging Face user, or None when
            nobody is logged in.

    Returns:
        A ``(status_message, results)`` pair for the two Gradio outputs. ``results``
        is a DataFrame of per-question answers, or None if the run stopped before
        any question was attempted.
    """
    space_id = os.getenv("SPACE_ID")
    if not profile:
        logger.error("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    logger.info(f"User logged in: {username}")
    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"
    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"

    try:
        agent = BasicAgent()
    except Exception as e:
        logger.error(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    logger.info(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=15)
        response.raise_for_status()
        questions_data = response.json()
    except Exception as e:
        logger.error(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None
    # The loop below calls item.get(). A non-list payload (e.g. an error object) used
    # to pass the emptiness check and then crash the handler with AttributeError.
    if not isinstance(questions_data, list) or not questions_data:
        logger.error("Fetched questions list is empty or has an invalid format.")
        return "Fetched questions list is empty or invalid format.", None
    logger.info(f"Fetched {len(questions_data)} questions.")

    results_log = []
    answers_payload = []
    logger.info(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        if not isinstance(item, dict):
            logger.warning(f"Skipping malformed question entry: {item!r}")
            continue
        task_id = item.get("task_id")
        question_text = item.get("question")
        if not task_id or question_text is None:
            logger.warning(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            submitted_answer = agent(question_text, task_id)
        except Exception as e:
            logger.error(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
            continue
        answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
        results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})

    if not answers_payload:
        logger.error("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", pd.DataFrame(results_log)

    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    logger.info(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=120)
        response.raise_for_status()
        result_data = response.json()
        logger.info(f"Server response: {result_data}")
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        return final_status, pd.DataFrame(results_log)
    except Exception as e:
        # requests exceptions carry the failing response (or None). Surface the
        # server's explanation, e.g. a validation error body, when there is one.
        err_response = getattr(e, "response", None)
        detail = ""
        if err_response is not None:
            detail = f" (HTTP {err_response.status_code}: {err_response.text[:500]})"
        logger.error(f"Submission failed: {e}{detail}")
        return f"Submission Failed: {e}{detail}", pd.DataFrame(results_log)
# --- Build Gradio Interface ---
# NOTE(review): "Mixtral-7B" is not a published model name; keep this disclaimer in
# sync with the model id actually loaded at startup.
_INSTRUCTIONS_MD = """
**Instructions:**
1. Log in to your Hugging Face account using the button below.
2. Click 'Run Evaluation & Submit All Answers' to fetch questions, run the JARVIS agent, and submit answers.
---
**Disclaimers:**
The agent uses a local Hugging Face model (Mixtral-7B) and async tools for the GAIA benchmark.
"""

with gr.Blocks() as demo:
    # Components render in creation order: title, instructions, login, button, outputs.
    gr.Markdown("# JARVIS Agent Evaluation Runner")
    gr.Markdown(_INSTRUCTIONS_MD)
    gr.LoginButton()
    run_button = gr.Button("Run Evaluation & Submit All Answers")
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
    # No explicit inputs: presumably Gradio supplies the OAuth profile from the
    # gr.OAuthProfile type hint on run_and_submit_all (confirm in Gradio OAuth docs).
    run_button.click(fn=run_and_submit_all, outputs=[status_output, results_table])
if __name__ == "__main__":
    # Startup banner, then launch the Gradio UI (debug mode, no public share link).
    _banner = "-" * 30
    logger.info(f"\n{_banner} App Starting {_banner}")
    logger.info(f"SPACE_ID: {os.getenv('SPACE_ID')}")
    logger.info("Launching Gradio Interface...")
    demo.launch(debug=True, share=False)