File size: 19,635 Bytes
40ddda8
bda2844
 
4985a8e
bda2844
40ddda8
4985a8e
 
bda2844
0230184
bda2844
 
 
4985a8e
 
011ec37
2da0ef4
44d380b
4985a8e
bda2844
4985a8e
bda2844
 
4985a8e
 
 
 
 
 
 
 
bda2844
4985a8e
2da0ef4
 
 
 
 
 
 
689cba9
2da0ef4
689cba9
2da0ef4
4985a8e
 
 
 
 
 
 
362c28f
44d380b
4985a8e
 
 
45fb8fa
 
 
44d380b
 
45fb8fa
44d380b
 
45fb8fa
44d380b
45fb8fa
 
 
362c28f
44d380b
45fb8fa
44d380b
 
45fb8fa
44d380b
45fb8fa
 
 
44d380b
45fb8fa
44d380b
4985a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94d2b35
4985a8e
 
 
 
 
45fb8fa
 
 
 
4985a8e
 
 
 
 
 
 
 
45fb8fa
 
 
 
 
 
 
 
 
4985a8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40ddda8
44d380b
 
 
 
 
 
 
 
 
 
30a99e4
4985a8e
40ddda8
689cba9
4985a8e
 
 
 
 
 
 
 
 
45fb8fa
4985a8e
45fb8fa
 
 
44d380b
 
 
45fb8fa
 
 
44d380b
 
45fb8fa
 
44d380b
 
45fb8fa
 
 
 
 
 
 
 
 
44d380b
4985a8e
 
45fb8fa
 
 
 
 
 
 
 
 
 
30a99e4
4985a8e
 
 
40ddda8
4985a8e
 
30a99e4
4985a8e
40ddda8
4985a8e
45fb8fa
4985a8e
 
 
 
 
 
bda2844
b108fbc
bda2844
 
 
40ddda8
 
4985a8e
 
45fb8fa
4985a8e
bda2844
 
 
40ddda8
45fb8fa
40ddda8
45fb8fa
30a99e4
4985a8e
 
b108fbc
4985a8e
 
bda2844
 
 
4985a8e
40ddda8
 
bda2844
b108fbc
 
bda2844
 
 
 
 
b108fbc
 
 
bda2844
 
 
 
 
 
 
 
 
 
 
4985a8e
bda2844
 
 
 
b108fbc
 
40ddda8
bda2844
 
 
40ddda8
bda2844
4985a8e
bda2844
 
 
 
 
40ddda8
 
 
bda2844
 
 
 
 
40ddda8
bda2844
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40ddda8
bda2844
 
 
 
 
40ddda8
4985a8e
40ddda8
 
 
4985a8e
 
 
 
45fb8fa
4985a8e
 
45fb8fa
 
40ddda8
4985a8e
 
 
45fb8fa
40ddda8
 
 
 
4985a8e
45fb8fa
 
 
 
 
 
 
40ddda8
4985a8e
40ddda8
 
 
 
 
45fb8fa
40ddda8
 
 
 
 
 
b108fbc
4985a8e
40ddda8
 
 
 
bda2844
4985a8e
b108fbc
40ddda8
bda2844
 
40ddda8
bda2844
4985a8e
40ddda8
4985a8e
 
2da0ef4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
import inspect
import operator
import os
from typing import TypedDict, Annotated, Sequence, List, Dict, Any, Optional

import gradio as gr
import pandas as pd
import requests
from dotenv import load_dotenv
from langchain_community.tools.arxiv.tool import ArxivQueryRun
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.tools.wikipedia.tool import WikipediaQueryRun
from langchain_community.utilities.arxiv import ArxivAPIWrapper
from langchain_community.utilities.wikipedia import WikipediaAPIWrapper
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import StateGraph, END

# --- Constants ---
# Base URL of the scoring service that serves questions and accepts submissions.
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"

# --- Environment Setup ---
# Must run before the os.getenv() lookups below so values from .env are visible.
load_dotenv()

OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")  # Optional: only gates the Tavily search tool

if not OPENROUTER_API_KEY:
    print("Warning: OPENROUTER_API_KEY not found in .env file. The LLM will not function.")

# --- Tool Setup ---
# `tools` is the module-level registry the agent reads; Tavily is added only
# when a key is configured, Wikipedia and arXiv are always registered.
tools = []
if not TAVILY_API_KEY:
    print("Warning: TAVILY_API_KEY not found in .env file. TavilySearchResults tool will not be available.")
else:
    tavily_tool = TavilySearchResults(max_results=3, api_key=TAVILY_API_KEY)
    tools.append(tavily_tool)

wikipedia_tool = WikipediaQueryRun(
    api_wrapper=WikipediaAPIWrapper(top_k_results=10, doc_content_chars_max=2000)
)
arxiv_tool = ArxivQueryRun(
    api_wrapper=ArxivAPIWrapper(top_k_results=10, doc_content_chars_max=2000)
)
tools += [wikipedia_tool, arxiv_tool]

# --- LangGraph Agent Definition ---
class AgentState(TypedDict):
    """State dictionary handed between the nodes of the LangGraph workflow."""
    # Conversation history. The operator.add annotation is the reducer LangGraph
    # applies to node outputs, so each node's returned messages are appended
    # to this sequence rather than replacing it.
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Declared for routing decisions (call tools vs. respond); no node visible
    # in this file reads or writes it.
    next_action: Optional[str]

class LangGraphAgent:
    """Question-answering agent built on a two-node LangGraph workflow.

    The graph starts at the ``llm`` node. After every LLM turn,
    ``_should_call_tools`` routes either to the ``tools`` node (whose output
    is fed back into ``llm``) or to ``END``. The chat model is reached through
    OpenRouter's OpenAI-compatible endpoint.

    Calling an instance with a question returns the cleaned answer text, or a
    string starting with ``"Error"`` when no answer could be produced.
    """

    # Lead-ins that models add in front of the answer despite the system
    # prompt. Matched case-insensitively and applied in this order.
    _ANSWER_PREFIXES = (
        "FINAL ANSWER:", "The answer is", "Here is the answer:",
        "The final answer is", "Answer:", "Solution:",
        "The direct answer is", "Here's the concise answer:",
        "Here you go:", "Certainly, the answer is",
    )
    _QUOTE_CHARS = ("\"", "'")

    def __init__(self, llm_choice: str = "qwen"):
        """Create the chat model for ``llm_choice`` and compile the graph.

        Args:
            llm_choice: ``"llama"`` or ``"qwen"``.

        Raises:
            ValueError: If ``OPENROUTER_API_KEY`` is unset or ``llm_choice``
                is not one of the supported values.
        """
        print(f"LangGraphAgent initializing with {llm_choice}...")
        if not OPENROUTER_API_KEY:
            raise ValueError("OPENROUTER_API_KEY is not set. Cannot initialize LLM.")

        self.llm_choice = llm_choice
        # Tool calling stays off for every model below: they are assumed not to
        # support it through the OpenAI-style tool binding on OpenRouter.
        self.supports_tool_calling = False

        if llm_choice == "llama":
            self.llm = ChatOpenAI(
                model="meta-llama/llama-3.1-8b-instruct:free",
                api_key=OPENROUTER_API_KEY,
                base_url="https://openrouter.ai/api/v1",
                temperature=0.1,
            )
            print("Initialized Llama 3.1 8B Instruct (tool calling assumed NOT supported).")
        elif llm_choice == "qwen":
            self.llm = ChatOpenAI(
                model="qwen/qwen-2-7b-instruct:free",
                api_key=OPENROUTER_API_KEY,
                base_url="https://openrouter.ai/api/v1",
                temperature=0.1,
            )
            print("Initialized Qwen-2 7B Instruct (tool calling assumed NOT supported).")
        else:
            raise ValueError(f"Unsupported LLM choice: {llm_choice}. Choose 'llama', or 'qwen'.")

        # Name -> tool lookup used by the tools node to dispatch tool calls.
        self.tools_map = {tool.name: tool for tool in tools}
        self.graph = self._build_graph()
        print("LangGraphAgent initialized.")

    def _build_graph(self):
        """Wire up and compile the ``llm`` <-> ``tools`` workflow."""
        workflow = StateGraph(AgentState)

        workflow.add_node("llm", self._call_llm)
        workflow.add_node("tools", self._tool_node)

        workflow.set_entry_point("llm")

        workflow.add_conditional_edges(
            "llm",
            self._should_call_tools,
            {
                "continue": "tools",
                "end": END,
            },
        )
        # Tool output always goes back to the LLM for another turn.
        workflow.add_edge("tools", "llm")
        return workflow.compile()

    def _should_call_tools(self, state: AgentState) -> str:
        """Return ``"continue"`` to run tools or ``"end"`` to finish.

        Tools are run only when tool calling is enabled and the most recent
        message carries a non-empty ``tool_calls`` attribute.
        """
        print("LLM deciding next step...")
        if not self.supports_tool_calling:
            print("Tool calling not supported by the current LLM. Ending interaction.")
            return "end"

        last_message = state["messages"][-1]
        if getattr(last_message, "tool_calls", None):
            print(f"LLM decided to call tools: {last_message.tool_calls}")
            return "continue"
        print("LLM decided to end.")
        return "end"

    def _call_llm(self, state: AgentState) -> Dict[str, Any]:
        """Invoke the chat model on the message history and return its reply.

        Returns:
            A state update of the form ``{"messages": [response]}``.
        """
        print(f"Calling LLM ({self.llm_choice})...")
        if self.supports_tool_calling:
            print("Binding tools to LLM for function calling.")
            llm_with_tools = self.llm.bind_tools(tools)
            response = llm_with_tools.invoke(state["messages"])
        else:
            print("Invoking LLM without binding tools.")
            response = self.llm.invoke(state["messages"])

        print(f"LLM response: {response.content[:100]}...")
        return {"messages": [response]}

    def _tool_node(self, state: AgentState) -> Dict[str, Any]:
        """Execute every tool call requested by the last message.

        Each call yields one ``ToolMessage`` tied to the originating
        ``tool_call_id``. Failures and unknown tool names are reported back as
        message content rather than raised, so the LLM can react to them.

        Returns:
            A state update of the form ``{"messages": [...]}``.
        """
        print("Executing tools...")
        tool_messages = []
        last_message = state["messages"][-1]

        if not getattr(last_message, "tool_calls", None):
            print("No tool calls found in the last message.")
            # The conditional edge should prevent this; kept as a fallback.
            return {"messages": [AIMessage(content="No tools to call, proceeding.")]}

        for tool_call in last_message.tool_calls:
            tool_name = tool_call["name"]
            tool_args = tool_call["args"]
            print(f"Calling tool: {tool_name} with args: {tool_args}")
            if tool_name in self.tools_map:
                try:
                    tool_result = self.tools_map[tool_name].invoke(tool_args)
                    print(f"Tool {tool_name} result (first 100 chars): {str(tool_result)[:100]}...")
                    content = str(tool_result)
                except Exception as e:
                    print(f"Error executing tool {tool_name}: {e}")
                    content = f"Error executing tool {tool_name}: {e}"
            else:
                print(f"Tool {tool_name} not found.")
                content = f"Tool {tool_name} not found."
            # ToolMessage (not FunctionMessage) is the message type that
            # answers an entry in ``tool_calls`` via its ``tool_call_id``.
            tool_messages.append(
                ToolMessage(content=content, name=tool_name, tool_call_id=tool_call["id"])
            )
        return {"messages": tool_messages}

    @classmethod
    def _clean_answer(cls, raw: str) -> str:
        """Strip boilerplate lead-ins and one layer of wrapping quotes.

        Returns the cleaned text, which may be empty when nothing is left.
        """
        answer = raw.strip()

        for prefix in cls._ANSWER_PREFIXES:
            if not answer.lower().startswith(prefix.lower()):
                continue
            rest = answer[len(prefix):]
            # Do not cut in the middle of a word ("The answer isn't ...").
            if prefix[-1].isalnum() and rest[:1].isalnum():
                continue
            # Also drop a colon that directly follows the lead-in
            # ("The answer is: Paris"); other punctuation such as a leading
            # minus sign is kept because it can be part of the answer.
            answer = rest.strip().lstrip(":").strip()

        # Only unwrap when the same quote character opens and closes the text.
        if len(answer) >= 2 and answer[0] == answer[-1] and answer[0] in cls._QUOTE_CHARS:
            inner = answer[1:-1]
            # Leave doubly quoted text alone: there the quotes are taken to be
            # part of the answer itself.
            if not (inner.startswith(cls._QUOTE_CHARS) and inner.endswith(cls._QUOTE_CHARS)):
                answer = inner

        return answer

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return the extracted answer.

        Args:
            question: The question text sent to the model as the human message.

        Returns:
            The cleaned answer. If cleaning leaves nothing, the last raw AI
            message is returned instead. When no answer can be extracted, or
            an exception occurs, a string starting with ``"Error"`` is
            returned; exceptions are never propagated to the caller.
        """
        print(f"Agent received question (first 100 chars): {question[:100]}...")

        system_prompt = (
            "You are an AI assistant designed to answer questions concisely. "
            "Your goal is to provide only the direct answer to the question, without any additional explanations, conversation, or prefixes like 'FINAL ANSWER:'. "
            "For example, if the question is 'What is the capital of France?', you should respond with 'Paris'. "
            "If the question asks for a list, provide it comma-separated, e.g., 'apple, banana, cherry'. "
            "If the question asks for a number, provide only the number, e.g., '42'."
        )
        initial_state = {"messages": [SystemMessage(content=system_prompt), HumanMessage(content=question)]}

        final_graph_state = None
        try:
            # Keep the most recent node output; an explicit END entry wins.
            for event in self.graph.stream(initial_state, {"recursion_limit": 100}):
                if END in event:
                    final_graph_state = event[END]
                    break
                for node_output in event.values():
                    final_graph_state = node_output

            if not final_graph_state or not final_graph_state["messages"]:
                print("Error: Agent did not reach a final state or no messages found.")
                return "Error: Agent did not produce a conclusive answer."

            # Candidate answers: AI messages with content and no pending tool
            # calls, newest first.
            candidates = [
                msg for msg in reversed(final_graph_state["messages"])
                if isinstance(msg, AIMessage) and not msg.tool_calls and msg.content
            ]

            for msg in candidates:
                answer = self._clean_answer(msg.content)
                if answer:
                    print(f"Agent returning answer: {answer}")
                    return answer

            # Cleaning emptied every candidate: fall back to raw text.
            print("No suitable AI message with valid content found after processing. Attempting to return last raw AI message if available.")
            last_ai_msg_content = next((msg.content.strip() for msg in candidates), None)
            if last_ai_msg_content:
                print(f"Agent returning last raw AI message as fallback: {last_ai_msg_content}")
                return last_ai_msg_content

            print("No suitable AI message found for final answer, even as fallback.")
            return "Error: Agent could not extract a valid answer."

        except Exception as e:
            # Top-level boundary: report the failure as the answer text so a
            # single bad question does not abort the whole evaluation run.
            print(f"Error during agent execution: {e}")
            import traceback
            traceback.print_exc()
            return f"Error during agent execution: {e}"

# --- Main Evaluation Logic (Modified from starter) ---
def run_and_submit_all(profile: gr.OAuthProfile | None, llm_model_choice: str):
    """
    Fetches all questions, runs the LangGraphAgent on them, submits all answers,
    and displays the results.

    Args:
        profile: Hugging Face OAuth profile of the logged-in user, or None
            when nobody is logged in.
        llm_model_choice: LLM key handed to LangGraphAgent ("llama" or "qwen").

    Returns:
        A ``(status_message, results)`` tuple. ``results`` is a pandas
        DataFrame with one row per processed question, or None when the run
        stopped before any question was processed.
    """
    space_id = os.getenv("SPACE_ID")

    if not profile:
        print("User not logged in.")
        return "Please Login to Hugging Face with the button.", None
    username = f"{profile.username}"
    print(f"User logged in: {username}")

    if not OPENROUTER_API_KEY:
        return "Error: OPENROUTER_API_KEY not found. Please set it in your .env file.", None
    # A missing TAVILY_API_KEY only produces a warning at tool setup time.

    api_url = DEFAULT_API_URL
    questions_url = f"{api_url}/questions"
    submit_url = f"{api_url}/submit"

    print(f"Attempting to initialize agent with LLM: {llm_model_choice}")
    try:
        agent = LangGraphAgent(llm_choice=llm_model_choice)
    except Exception as e:
        print(f"Error instantiating agent: {e}")
        return f"Error initializing agent: {e}", None

    agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main" if space_id else "local_run_no_space_id"
    print(f"Agent code link: {agent_code}")

    print(f"Fetching questions from: {questions_url}")
    try:
        response = requests.get(questions_url, timeout=20)
        response.raise_for_status()
        questions_data = response.json()
    except requests.exceptions.JSONDecodeError as e:
        # Must come before RequestException: requests' JSONDecodeError is a
        # subclass of it, so the reverse order makes this branch unreachable.
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return f"Error decoding server response for questions: {e}", None
    except requests.exceptions.RequestException as e:
        print(f"Error fetching questions: {e}")
        return f"Error fetching questions: {e}", None

    # Anything other than a non-empty list cannot be iterated as questions.
    if not questions_data or not isinstance(questions_data, list):
        print("Fetched questions list is empty.")
        return "Fetched questions list is empty or invalid format.", None
    print(f"Fetched {len(questions_data)} questions.")

    results_log = []
    answers_payload = []
    print(f"Running agent on {len(questions_data)} questions...")
    for item in questions_data:
        is_record = isinstance(item, dict)
        task_id = item.get("task_id") if is_record else None
        question_text = item.get("question") if is_record else None
        if not task_id or question_text is None:
            print(f"Skipping item with missing task_id or question: {item}")
            continue
        try:
            print(f"\n--- Processing Task ID: {task_id} ---")
            submitted_answer = agent(question_text)
            answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
        except Exception as e:
            # Keep going: one failing question must not lose the other answers.
            print(f"Error running agent on task {task_id}: {e}")
            results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})

    results_df = pd.DataFrame(results_log)

    if not answers_payload:
        print("Agent did not produce any answers to submit.")
        return "Agent did not produce any answers to submit.", results_df

    submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
    status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
    print(status_update)

    print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
    try:
        response = requests.post(submit_url, json=submission_data, timeout=60)
        response.raise_for_status()
        result_data = response.json()
        final_status = (
            f"Submission Successful!\n"
            f"User: {result_data.get('username')}\n"
            f"Overall Score: {result_data.get('score', 'N/A')}% "
            f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
            f"Message: {result_data.get('message', 'No message received.')}"
        )
        print("Submission successful.")
        return final_status, results_df
    except requests.exceptions.HTTPError as e:
        error_detail = f"Server responded with status {e.response.status_code}."
        try:
            error_json = e.response.json()
            error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
        except requests.exceptions.JSONDecodeError:
            error_detail += f" Response: {e.response.text[:500]}"
        status_message = f"Submission Failed: {error_detail}"
    except requests.exceptions.Timeout:
        status_message = "Submission Failed: The request timed out."
    except requests.exceptions.RequestException as e:
        status_message = f"Submission Failed: Network error - {e}"
    except Exception as e:
        status_message = f"An unexpected error occurred during submission: {e}"

    # Every failure path ends here and still returns the per-question log.
    print(status_message)
    return status_message, results_df

# Gradio UI: instructions, login button, model selector, run button and the
# two output widgets filled by run_and_submit_all.
with gr.Blocks() as demo:
    gr.Markdown("# LangGraph GAIA Agent Evaluation Runner")
    # NOTE(review): the disclaimer below states that some models (e.g., llama)
    # support tool calling, but LangGraphAgent sets supports_tool_calling to
    # False for every choice; the text and the code should be reconciled.
    gr.Markdown(
        """
        **Instructions:**
        1.  **Clone this space** if you haven't already.
        2.  **Create a `.env` file** in the root of your space with your API keys:
            ```
            OPENROUTER_API_KEY="your_openrouter_api_key"
            TAVILY_API_KEY="your_tavily_api_key" # Optional, but TavilySearch tool won't work without it
            ```
        3.  Log in to your Hugging Face account using the button below. This uses your HF username for submission.
        4.  **Select the LLM model** you want the agent to use.
        5.  Click 'Run Evaluation & Submit All Answers' to fetch questions, run your agent, submit answers, and see the score.
        ---
        **Disclaimers:**
        -   Ensure your Hugging Face Space is public for the `agent_code` link to be verifiable.
        -   Submitting all answers can take some time as the agent processes each question.
        -   The agent will use the selected LLM. Note that only some models (e.g., llama) support tool/function calling. If a model without tool support is chosen for a task requiring tools, it may not perform optimally or might not use tools.
        """
    )

    gr.LoginButton()

    # The choices must match the llm_choice values LangGraphAgent accepts.
    llm_choice_dropdown = gr.Dropdown(
        choices=["llama", "qwen"],
        value="llama", # Default selection; tool calling is disabled for both choices in LangGraphAgent
        label="Select LLM Model",
        info="Choose the Large Language Model for the agent."
    )

    run_button = gr.Button("Run Evaluation & Submit All Answers")

    # Outputs: the status text and the per-question results table.
    status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
    results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)

    run_button.click(
        fn=run_and_submit_all,
        inputs=[llm_choice_dropdown], # Only the dropdown is listed; presumably Gradio injects `profile` from its gr.OAuthProfile annotation (confirm)
        outputs=[status_output, results_table]
    )

if __name__ == "__main__":
    # Startup banner: report the Space-related environment variables so it is
    # obvious from the logs whether the app runs on a Space or locally.
    print("\n" + "-"*30 + " App Starting " + "-"*30)
    space_host_startup = os.getenv("SPACE_HOST")
    space_id_startup = os.getenv("SPACE_ID")

    if space_host_startup:
        print(f"✅ SPACE_HOST found: {space_host_startup}")
        # SPACE_HOST normally already ends in ".hf.space"; append the suffix
        # only when it is missing so the printed URL is not doubled.
        if space_host_startup.endswith(".hf.space"):
            runtime_host = space_host_startup
        else:
            runtime_host = f"{space_host_startup}.hf.space"
        print(f"   Runtime URL should be: https://{runtime_host}")
    else:
        print("ℹ️  SPACE_HOST environment variable not found (running locally?).")

    if space_id_startup:
        print(f"✅ SPACE_ID found: {space_id_startup}")
        print(f"   Repo URL: https://huggingface.co/spaces/{space_id_startup}")
        print(f"   Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
    else:
        print("ℹ️  SPACE_ID environment variable not found (running locally?). Repo URL cannot be determined.")

    print("-"*(60 + len(" App Starting ")) + "\n")

    print("Launching Gradio Interface for LangGraph GAIA Agent Evaluation...")
    demo.launch(debug=True, share=False)