gabejavitt commited on
Commit
4277297
·
verified ·
1 Parent(s): 8a7fdce

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +535 -193
app.py CHANGED
@@ -1,4 +1,3 @@
1
-
2
  import os
3
  import io
4
  import json
@@ -8,11 +7,12 @@ import contextlib
8
  import uuid
9
  import time
10
  import ast
11
- from typing import List, Optional, TypedDict, Annotated
12
  from pathlib import Path
 
13
 
14
- import gradio as gr
15
  import pandas as pd
 
16
  import torch
17
  from pydantic import BaseModel, Field
18
 
@@ -41,11 +41,12 @@ from langchain_core.documents import Document
41
  # CONFIGURATION
42
  # =============================================================================
43
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
44
- MAX_TURNS = 20
45
  MAX_MESSAGE_LENGTH = 8000
 
46
 
47
  # =============================================================================
48
- # GLOBAL RAG COMPONENTS (Initialize once)
49
  # =============================================================================
50
  global_embeddings = None
51
  global_text_splitter = None
@@ -138,7 +139,138 @@ def find_file(path: str) -> Optional[Path]:
138
  return None
139
 
140
  # =============================================================================
141
- # TOOL DEFINITIONS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  # =============================================================================
143
 
144
  class SearchInput(BaseModel):
@@ -146,11 +278,19 @@ class SearchInput(BaseModel):
146
 
147
  @tool(args_schema=SearchInput)
148
  def search_tool(query: str) -> str:
149
- """Calls DuckDuckGo search and returns the results. Use this for recent information or general web searches."""
 
 
 
 
 
 
 
 
150
  if not isinstance(query, str) or not query.strip():
151
  return "Error: Invalid input. 'query' must be a non-empty string."
152
 
153
- print(f"--- Calling Search Tool with query: {query} ---")
154
  try:
155
  search = DuckDuckGoSearchRun()
156
  result = search.run(query)
@@ -161,25 +301,70 @@ def search_tool(query: str) -> str:
161
  return f"Error running search for '{query}': {str(e)}"
162
 
163
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  class CodeInput(BaseModel):
165
- code: str = Field(description="The Python code to execute, which must include a print() statement for output.")
166
 
167
  @tool(args_schema=CodeInput)
168
  def code_interpreter(code: str) -> str:
169
  """
170
- Executes a string of Python code and returns its stdout, stderr, and any error.
171
- CRITICAL RULES:
172
- 1. ALWAYS use print() to output your final answer.
173
- 2. Write simple, focused code. One task per execution.
174
- 3. Add comments (#) to explain your logic.
175
- 4. SCOPE RULE: Import all necessary libraries inside any function you define.
176
- Available: pandas as pd, basic Python libraries.
 
 
 
 
 
 
 
 
177
  """
178
  if not isinstance(code, str):
179
  return "Error: Invalid input. 'code' must be a string."
180
 
181
- # Basic safety checks
182
- dangerous_patterns = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system']
183
  code_lower = code.lower()
184
  for pattern in dangerous_patterns:
185
  if pattern in code_lower:
@@ -188,7 +373,7 @@ def code_interpreter(code: str) -> str:
188
  if 'open(' in code_lower and any(mode in code for mode in ["'w'", '"w"', "'a'", '"a"', "'wb'", '"wb"']):
189
  return "Error: Writing files is not allowed in code_interpreter. Use write_file tool instead."
190
 
191
- print(f"--- Calling Code Interpreter ---\nCode:\n{code}\n---")
192
  output_stream = io.StringIO()
193
  error_stream = io.StringIO()
194
 
@@ -196,6 +381,9 @@ def code_interpreter(code: str) -> str:
196
  with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream):
197
  safe_globals = {
198
  "pd": pd,
 
 
 
199
  "__builtins__": __builtins__
200
  }
201
  exec(code, safe_globals, {})
@@ -209,9 +397,9 @@ def code_interpreter(code: str) -> str:
209
  if stdout:
210
  if len(stdout) > MAX_MESSAGE_LENGTH:
211
  stdout = stdout[:MAX_MESSAGE_LENGTH] + f"\n...[truncated, {len(stdout)} total chars]"
212
- return f"Success:\n{stdout}"
213
 
214
- return "Success: Code executed without error but produced no output.\n⚠️ Remember to use print() to output your results!"
215
 
216
  except Exception as e:
217
  tb_str = traceback.format_exc()
@@ -219,15 +407,15 @@ def code_interpreter(code: str) -> str:
219
 
220
 
221
  class ReadFileInput(BaseModel):
222
- path: str = Field(description="The path to the file to read.")
223
 
224
  @tool(args_schema=ReadFileInput)
225
  def read_file(path: str) -> str:
226
- """Reads the content of a file at the specified path."""
227
  if not isinstance(path, str) or not path.strip():
228
  return "Error: Invalid input. 'path' must be a non-empty string."
229
 
230
- print(f"--- Calling Read File Tool: {path} ---")
231
 
232
  file_path = find_file(path)
233
  if not file_path:
@@ -249,18 +437,18 @@ def read_file(path: str) -> str:
249
 
250
 
251
  class WriteFileInput(BaseModel):
252
- path: str = Field(description="The path of the file to write to.")
253
- content: str = Field(description="The content to write into the file.")
254
 
255
  @tool(args_schema=WriteFileInput)
256
  def write_file(path: str, content: str) -> str:
257
- """Writes content to a file at the specified path."""
258
  if not isinstance(path, str) or not path.strip():
259
  return "Error: Invalid input. 'path' must be a non-empty string."
260
  if not isinstance(content, str):
261
  return "Error: Invalid input. 'content' must be a string."
262
 
263
- print(f"--- Calling Write File Tool: {path} ---")
264
 
265
  try:
266
  file_path = Path.cwd() / path
@@ -272,12 +460,12 @@ def write_file(path: str, content: str) -> str:
272
 
273
 
274
  class ListDirInput(BaseModel):
275
- path: str = Field(description="The directory path to list.", default=".")
276
 
277
  @tool(args_schema=ListDirInput)
278
  def list_directory(path: str = ".") -> str:
279
- """Lists the contents of a directory."""
280
- print(f"--- Calling List Directory Tool: {path} ---")
281
 
282
  try:
283
  dir_path = Path.cwd() / path if path != "." else Path.cwd()
@@ -311,15 +499,15 @@ def list_directory(path: str = ".") -> str:
311
 
312
 
313
  class AudioInput(BaseModel):
314
- file_path: str = Field(description="The file path of the audio to transcribe.")
315
 
316
  @tool(args_schema=AudioInput)
317
  def audio_transcription_tool(file_path: str) -> str:
318
- """Transcribes an audio file to text using Whisper."""
319
  if not isinstance(file_path, str) or not file_path.strip():
320
  return "Error: Invalid input. 'file_path' must be a non-empty string."
321
 
322
- print(f"--- Calling Audio Transcription: {file_path} ---")
323
 
324
  if asr_pipeline is None:
325
  return "Error: ASR pipeline is not available."
@@ -339,17 +527,16 @@ def audio_transcription_tool(file_path: str) -> str:
339
  except Exception as e:
340
  return f"Error transcribing '{file_path}': {str(e)}"
341
 
342
-
343
  class YoutubeInput(BaseModel):
344
- video_url: str = Field(description="The URL of the YouTube video.")
345
 
346
  @tool(args_schema=YoutubeInput)
347
  def get_youtube_transcript(video_url: str) -> str:
348
- """Fetches the transcript/captions for a YouTube video."""
349
  if not isinstance(video_url, str) or not video_url.strip():
350
  return "Error: Invalid input. 'video_url' must be a non-empty string."
351
 
352
- print(f"--- Calling YouTube Transcript: {video_url} ---")
353
 
354
  try:
355
  video_id = None
@@ -373,94 +560,76 @@ def get_youtube_transcript(video_url: str) -> str:
373
 
374
 
375
  class ScrapeInput(BaseModel):
376
- url: str = Field(description="The URL to scrape (must start with http:// or https://).")
377
- query: str = Field(description="The specific question to answer or information to find on the page.")
378
 
379
  @tool(args_schema=ScrapeInput)
380
  def scrape_and_retrieve(url: str, query: str) -> str:
381
  """
382
- Scrapes a webpage, embeds its content using RAG, and retrieves relevant sections based on the query.
383
- Use this to extract specific information from web pages.
 
 
 
 
384
  """
385
  if not (url.lower().startswith(('http://', 'https://'))):
386
  return f"Error: Invalid URL. Must start with http:// or https://. Got: '{url}'"
387
  if not query or not query.strip():
388
  return "Error: A query is required to search the page content."
389
 
390
- # Check if RAG components are initialized
391
  if global_embeddings is None or global_text_splitter is None:
392
  if not initialize_rag_components():
393
  return "Error: RAG components could not be initialized."
394
 
395
- print(f"--- Calling RAG Scraper: {url} for query: '{query}' ---")
396
 
397
  try:
398
- # Fetch the webpage
399
  headers = {
400
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
401
  }
402
- print(f"Fetching URL: {url}")
403
  response = requests.get(url, headers=headers, timeout=20)
404
  response.raise_for_status()
405
 
406
- # Parse HTML
407
  soup = BeautifulSoup(response.text, 'html.parser')
408
 
409
- # Remove unwanted tags
410
  for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe", "noscript"]):
411
  tag.extract()
412
 
413
- # Try to find main content
414
  main_content = soup.find('main') or soup.find('article') or soup.find('div', class_=re.compile('content|main|article', re.I)) or soup.body
415
 
416
  if not main_content:
417
  return "Error: Could not find main content on the page."
418
 
419
- # Extract text
420
  text = main_content.get_text(separator='\n', strip=True)
421
-
422
- # Clean up text - remove extra whitespace and empty lines
423
  lines = [line.strip() for line in text.splitlines()]
424
  text = '\n'.join(line for line in lines if line)
425
 
426
  if not text or len(text) < 50:
427
  return f"Error: Scraped content was too short or empty (length: {len(text)})."
428
 
429
- print(f"Scraped text length: {len(text)} characters")
430
-
431
- # Split text into chunks
432
  chunks = global_text_splitter.split_text(text)
433
 
434
  if not chunks:
435
  return "Error: Text could not be split into chunks."
436
 
437
- print(f"Created {len(chunks)} chunks")
438
-
439
- # Create Document objects
440
  docs = [Document(page_content=chunk, metadata={"source": url}) for chunk in chunks]
441
 
442
- # Create FAISS vector store
443
- print("Creating embeddings and vector store...")
444
  db = FAISS.from_documents(docs, global_embeddings)
445
 
446
- # Retrieve relevant chunks
447
- print(f"Searching for: '{query}'")
448
  retriever = db.as_retriever(search_kwargs={"k": 5})
449
  retrieved_docs = retriever.invoke(query)
450
 
451
  if not retrieved_docs:
452
  return f"No relevant information found on {url} for query: '{query}'\n\nThe page was successfully scraped but doesn't seem to contain information matching your query."
453
 
454
- print(f"Retrieved {len(retrieved_docs)} relevant chunks")
455
-
456
- # Combine retrieved chunks
457
  context_parts = []
458
  for i, doc in enumerate(retrieved_docs, 1):
459
  context_parts.append(f"[Chunk {i}]\n{doc.page_content}")
460
 
461
  context = "\n\n---\n\n".join(context_parts)
462
 
463
- result = f"Successfully retrieved relevant information from {url}\n\nQuery: {query}\n\n{context}"
464
 
465
  return truncate_if_needed(result)
466
 
@@ -472,13 +641,24 @@ def scrape_and_retrieve(url: str, query: str) -> str:
472
 
473
 
474
  class FinalAnswerInput(BaseModel):
475
- answer: str = Field(description="The final, definitive answer to the question.")
476
 
477
  @tool(args_schema=FinalAnswerInput)
478
  def final_answer_tool(answer: str) -> str:
479
  """
480
- Call this tool ONLY when you have the final, definitive answer.
481
- The 'answer' must be EXACTLY what was asked for, with no extra text.
 
 
 
 
 
 
 
 
 
 
 
482
  """
483
  if not isinstance(answer, str):
484
  try:
@@ -486,24 +666,60 @@ def final_answer_tool(answer: str) -> str:
486
  except:
487
  return "Error: Invalid input. 'answer' must be a string."
488
 
489
- print(f"--- FINAL ANSWER TOOL CALLED ---")
490
- print(f"Answer: {answer}")
491
  return answer
492
 
493
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
494
  # =============================================================================
495
  # FALLBACK PARSER
496
  # =============================================================================
497
  def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
498
- """
499
- Parses malformed tool call strings from an LLM response.
500
- """
501
- print(f"Original LLM content for fallback parsing:\n---\n{content[:500]}\n---")
502
  tool_name = None
503
  tool_input = None
504
  cleaned_str = None
505
 
506
- # STRATEGY 1: Try to parse <function(tool_name)>...{json_string}...
507
  func_match = re.search(
508
  r"<function[(=]\s*([^)]+)\s*[)>](.*)",
509
  content,
@@ -523,28 +739,26 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
523
  cleaned_str = cleaned_str.strip().rstrip(',')
524
 
525
  tool_input = json.loads(cleaned_str)
526
- print(f"🔧 Fallback (Format 1 - json.loads): Parsed tool call for '{tool_name}'")
527
  else:
528
- print(f"⚠️ Fallback (Format 1): Found <function> but no JSON blob.")
529
  tool_name = None
530
 
531
  except json.JSONDecodeError as e:
532
- print(f"⚠️ Fallback (Format 1): json.loads failed: {e}. Trying ast.literal_eval.")
533
  try:
534
  if cleaned_str:
535
  potential_input = ast.literal_eval(cleaned_str)
536
  if isinstance(potential_input, dict):
537
  tool_input = potential_input
538
- print(f"🔧 Fallback (Format 1 - ast.literal_eval): Parsed tool call for '{tool_name}'")
539
  else:
540
- print(f"⚠️ Fallback (Format 1): ast.literal_eval did not produce a dict.")
541
  tool_name = None
542
  else:
543
  tool_name = None
544
  except:
545
  tool_name = None
546
 
547
- # FINAL VALIDATION
548
  if tool_name and tool_input is not None:
549
  if any(t.name == tool_name for t in tools):
550
  tool_call = ToolCall(
@@ -556,79 +770,52 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
556
  return [tool_call]
557
  else:
558
  print(f"❌ Tool '{tool_name}' not found in available tools")
559
- print(f" Available: {[t.name for t in tools]}")
560
 
561
  print("❌ Failed to parse any valid tool call from content")
562
  return []
563
 
564
 
565
- # =============================================================================
566
- # DEFINED TOOLS LIST
567
- # =============================================================================
568
- defined_tools = [
569
- search_tool,
570
- code_interpreter,
571
- read_file,
572
- write_file,
573
- list_directory,
574
- audio_transcription_tool,
575
- get_youtube_transcript,
576
- scrape_and_retrieve,
577
- final_answer_tool
578
- ]
579
-
580
-
581
- # =============================================================================
582
- # AGENT STATE
583
- # =============================================================================
584
- class AgentState(TypedDict):
585
- messages: Annotated[List[AnyMessage], add_messages]
586
- turn: int
587
-
588
-
589
  # =============================================================================
590
  # CONDITIONAL EDGE FUNCTION
591
  # =============================================================================
592
  def should_continue(state: AgentState):
593
- """
594
- Decide whether to continue, call tools, or end.
595
- """
596
  last_message = state['messages'][-1]
597
  current_turn = state.get('turn', 0)
598
 
599
- # 1. Check for final_answer_tool
600
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
601
  for tool_call in last_message.tool_calls:
602
  if tool_call.get("name") == "final_answer_tool":
603
  print("--- Condition: final_answer_tool called, ending. ---")
604
  return END
605
 
606
- # 2. Check turn limit
607
  if current_turn >= MAX_TURNS:
608
  print(f"--- Condition: Max turns ({MAX_TURNS}) reached. Ending. ---")
609
  return END
610
 
611
- # 3. Route to tools if tool calls exist
612
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
613
  print("--- Condition: Tools called, routing to tools node. ---")
614
  return "tools"
615
 
616
- # 4. Loop prevention
617
  if len(state['messages']) > 2 and isinstance(last_message, AIMessage) and isinstance(state['messages'][-2], AIMessage):
618
  print(f"--- Condition: Detected 2+ consecutive AI messages (Turn {current_turn}). Ending to prevent loop. ---")
619
  return END
620
 
621
- # 5. Loop back to agent
622
  print(f"--- Condition: No tool call (Turn {current_turn}). Continuing to agent. ---")
623
  return "agent"
624
 
625
 
626
  # =============================================================================
627
- # BASIC AGENT CLASS
628
  # =============================================================================
629
- class BasicAgent:
630
  def __init__(self):
631
- print("BasicAgent (Single LLM) initializing...")
632
 
633
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
634
  if not GROQ_API_KEY:
@@ -638,7 +825,7 @@ class BasicAgent:
638
 
639
  # Initialize RAG Components
640
  if not initialize_rag_components():
641
- print("⚠️ Warning: RAG components failed to initialize. scrape_and_retrieve may not work.")
642
 
643
  # Build tool descriptions
644
  tool_desc_list = []
@@ -656,31 +843,104 @@ class BasicAgent:
656
  tool_desc_list.append(desc)
657
  tool_descriptions = "\n".join(tool_desc_list)
658
 
659
- # System Prompt
660
- self.system_prompt = f"""You are a highly intelligent AI assistant for the GAIA benchmark.
661
- Your goal: Provide the EXACT answer in the EXACT format requested.
662
-
663
- **PROTOCOL:**
664
- 1. **ANALYZE:** Read the question and history. What is the next logical step?
665
- 2. **ACT:** Call ONE tool to get information or perform a calculation.
666
- 3. **EVALUATE:** Look at the tool's output. Do you have the final answer?
667
- - **If NO:** Go back to Step 1 and decide the *next* step.
668
- - **If YES:** Call final_answer_tool immediately with the answer.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
669
 
670
- **CRITICAL RULES:**
671
- - **TOOL USE:** You MUST use tools to find the answer. Do NOT use your own knowledge.
672
- - **FINAL ANSWER:** When you have the answer, use final_answer_tool. The 'answer' argument must be the answer ONLY (e.g., "42", "red, blue, green").
673
- - **NO CONVERSATIONAL TEXT:** Never add phrases like "The answer is" or "Based on the information". Just the answer.
674
-
675
- **TOOLS:**
676
  {tool_descriptions}
677
 
678
- **REMEMBER:** One step at a time. Use tools. Call final_answer_tool when done.
 
 
679
  """
680
 
681
  print("Initializing Groq LLM...")
682
  try:
683
- # Changed from tool_choice="any" to "auto" for better flexibility
684
  self.llm_with_tools = ChatGroq(
685
  temperature=0,
686
  groq_api_key=GROQ_API_KEY,
@@ -688,27 +948,51 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
688
  max_tokens=4096,
689
  timeout=60
690
  ).bind_tools(self.tools, tool_choice="auto")
691
- print("✅ Main LLM (llama-3.3-70b-versatile with tools) initialized.")
692
 
693
  except Exception as e:
694
  print(f"❌ Error initializing Groq: {e}")
695
  raise
696
 
697
- # Agent Node
698
  def agent_node(state: AgentState):
699
  current_turn = state.get('turn', 0) + 1
700
- print(f"\n{'='*60}")
701
- print(f"AGENT TURN {current_turn}/{MAX_TURNS}")
702
- print('='*60)
703
 
704
  if current_turn > MAX_TURNS:
705
- return {"messages": [SystemMessage(content="Max turns reached.")], "turn": current_turn}
706
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
707
  max_retries = 3
708
  ai_message = None
709
  for attempt in range(max_retries):
710
  try:
711
- ai_message = self.llm_with_tools.invoke(state["messages"])
712
  break
713
  except Exception as e:
714
  print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {e}")
@@ -718,32 +1002,58 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
718
  )
719
  time.sleep(2 ** attempt)
720
 
721
- # Fallback Parsing Logic
722
  if not ai_message.tool_calls and isinstance(ai_message.content, str) and ai_message.content.strip():
723
  parsed_tool_calls = parse_tool_call_from_string(ai_message.content, self.tools)
724
  if parsed_tool_calls:
725
- print("🔧 Fallback SUCCESS: Rebuilding tool call(s).")
726
  ai_message.tool_calls = parsed_tool_calls
727
  ai_message.content = ""
728
- else:
729
- print(f"⚠️ Fallback FAILED: Could not parse any tool call from content:\n{ai_message.content[:200]}...")
 
 
730
 
731
  if ai_message.tool_calls:
732
- print(f"🔧 Agent Tool Call: {ai_message.tool_calls[0]['name']}")
 
 
 
 
 
733
  else:
734
- print(f"💭 Agent Reasoning: {ai_message.content[:200]}...")
735
 
736
- return {"messages": [ai_message], "turn": current_turn}
737
-
738
- # Tool Node
739
- tool_node = ToolNode(self.tools)
 
 
740
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
  # Build Graph
742
- print("Building Single-Agent graph...")
743
  graph_builder = StateGraph(AgentState)
744
 
745
  graph_builder.add_node("agent", agent_node)
746
- graph_builder.add_node("tools", tool_node)
747
 
748
  graph_builder.add_edge(START, "agent")
749
 
@@ -760,87 +1070,119 @@ Your goal: Provide the EXACT answer in the EXACT format requested.
760
  graph_builder.add_edge("tools", "agent")
761
 
762
  self.graph = graph_builder.compile()
763
- print("✅ Single-Agent graph compiled successfully.")
764
 
765
  def __call__(self, question: str) -> str:
766
- print(f"\n--- Starting Agent Run for Question ---")
767
- print(f"Agent received question (first 100 chars): {question[:100]}...")
 
 
 
768
 
769
  graph_input = {
770
  "messages": [
771
  SystemMessage(content=self.system_prompt),
772
  HumanMessage(content=question)
773
  ],
774
- "turn": 0
 
 
 
775
  }
776
 
777
  final_answer = "AGENT FAILED TO PRODUCE ANSWER"
778
  try:
779
- config = {"recursion_limit": MAX_TURNS + 5}
780
  for event in self.graph.stream(graph_input, stream_mode="values", config=config):
781
 
782
- if event.get('messages'): # Ensure messages exist
783
- last_message = event["messages"][-1]
784
- else:
785
- continue # Skip if no messages yet
786
 
787
  # Check for final answer extraction
788
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
789
  if last_message.tool_calls[0].get("name") == "final_answer_tool":
790
  final_answer_args = last_message.tool_calls[0].get('args', {})
791
  if 'answer' in final_answer_args:
792
- final_answer = final_answer_args['answer']
793
- print(f"--- Final Answer Captured from tool call: '{final_answer}' ---")
794
- break
 
 
795
  else:
796
- print(f"⚠️ Final Answer tool called without 'answer' argument: {final_answer_args}")
797
- final_answer = "ERROR: FINAL_ANSWER_TOOL CALLED WITHOUT ANSWER"
798
- break
799
 
800
  elif isinstance(last_message, ToolMessage):
801
- print(f"Tool Result ({last_message.tool_call_id}): {last_message.content[:500]}...")
 
802
  elif isinstance(last_message, AIMessage) and not last_message.tool_calls:
803
- print(f"AI Message (Reasoning): {last_message.content[:500]}...")
804
- elif isinstance(last_message, SystemMessage):
805
- print(f"System Message: {last_message.content[:500]}...")
806
-
807
 
808
- # --- Final Answer Cleaning ---
809
  cleaned_answer = str(final_answer).strip()
810
- prefixes_to_remove = ["The answer is:", "Here is the answer:", "Based on the information:", "Final Answer:", "Answer:"]
811
- original_cleaned = cleaned_answer
 
 
 
 
 
812
  for prefix in prefixes_to_remove:
813
  if cleaned_answer.lower().startswith(prefix.lower()):
814
  potential_answer = cleaned_answer[len(prefix):].strip()
815
- if potential_answer:
816
  cleaned_answer = potential_answer
817
- break
818
 
 
819
  cleaned_answer = remove_fences_simple(cleaned_answer)
820
- if cleaned_answer.startswith("`") and cleaned_answer.endswith("`"):
821
- cleaned_answer = cleaned_answer[1:-1].strip()
822
-
823
- print(f"Agent returning final answer (cleaned): '{cleaned_answer}'")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
  return cleaned_answer
825
 
826
  except Exception as e:
827
- print(f"Error running agent graph: {e}")
828
  tb_str = traceback.format_exc()
829
  print(tb_str)
830
  return f"AGENT GRAPH ERROR: {e}"
831
 
832
 
833
- # ====================================================
834
- # --- Global Agent Instantiation ---
835
-
836
  try:
837
- agent = BasicAgent()
838
- print("✅ Global BasicAgent instantiated successfully.")
839
- if asr_pipeline is None: print("⚠️ Global ASR Pipeline failed load.")
 
 
 
840
  except Exception as e:
841
  print(f"❌ FATAL: Could not instantiate global agent: {e}")
842
  traceback.print_exc()
843
  agent = None
 
844
 
845
  # ====================================================
846
  # --- (Original Template Code - Mock Questions Version) ---
 
 
1
  import os
2
  import io
3
  import json
 
7
  import uuid
8
  import time
9
  import ast
10
+ from typing import List, Optional, TypedDict, Annotated, Dict
11
  from pathlib import Path
12
+ from collections import Counter
13
 
 
14
  import pandas as pd
15
+ import numpy as np
16
  import torch
17
  from pydantic import BaseModel, Field
18
 
 
41
  # CONFIGURATION
42
  # =============================================================================
43
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
44
+ MAX_TURNS = 25 # Increased for planning/reflection
45
  MAX_MESSAGE_LENGTH = 8000
46
+ REFLECT_EVERY_N_TURNS = 5
47
 
48
  # =============================================================================
49
+ # GLOBAL RAG COMPONENTS
50
  # =============================================================================
51
  global_embeddings = None
52
  global_text_splitter = None
 
139
  return None
140
 
141
  # =============================================================================
142
+ # PLANNING & REFLECTION TOOLS
143
+ # =============================================================================
144
+
145
+ class PlanInput(BaseModel):
146
+ question: str = Field(description="The question to create a plan for")
147
+
148
+ @tool(args_schema=PlanInput)
149
+ def create_plan(question: str) -> str:
150
+ """
151
+ Creates a step-by-step plan for answering a question.
152
+ CRITICAL: Call this FIRST for any multi-step or complex question.
153
+
154
+ This helps you think through:
155
+ 1. What information do you need?
156
+ 2. In what order should you gather it?
157
+ 3. What tools will you use?
158
+
159
+ After calling this, execute the plan step-by-step.
160
+ """
161
+ print(f"📋 Planning phase initiated for: {question[:100]}...")
162
+
163
+ return f"""✅ Plan Created. Now execute these steps methodically:
164
+
165
+ PLANNING FRAMEWORK:
166
+ 1. GOAL: What exact answer format is needed?
167
+ 2. REQUIREMENTS: What data/information is required?
168
+ 3. STRATEGY: What's the most efficient path?
169
+ 4. EXECUTION: List concrete actions in order
170
+
171
+ Now proceed with Step 1 of your plan."""
172
+
173
+
174
+ class ReflectInput(BaseModel):
175
+ current_situation: str = Field(description="Brief summary of what you've tried and where you are stuck")
176
+
177
+ @tool(args_schema=ReflectInput)
178
+ def reflect_on_progress(current_situation: str) -> str:
179
+ """
180
+ Reflects on your progress and suggests what to do next.
181
+
182
+ Call this when:
183
+ - You feel stuck or uncertain
184
+ - Tools keep failing
185
+ - You're not making progress
186
+ - You've taken 5+ steps without getting closer to the answer
187
+
188
+ This helps you step back and reconsider your approach.
189
+ """
190
+ print(f"🤔 Reflection initiated: {current_situation[:100]}...")
191
+
192
+ return f"""🔍 REFLECTION ANALYSIS:
193
+
194
+ Current situation: {current_situation}
195
+
196
+ CRITICAL QUESTIONS TO ASK YOURSELF:
197
+ 1. Have I gathered the information I actually need?
198
+ 2. Am I using the right tools for this task?
199
+ 3. Am I going in circles (repeating similar actions)?
200
+ 4. Should I try a completely different approach?
201
+ 5. Do I have enough information to answer now?
202
+
203
+ NEXT STEPS:
204
+ - If stuck: Try a different tool or search query
205
+ - If missing info: Identify exactly what's missing
206
+ - If have info: Proceed to final_answer_tool
207
+ - If uncertain: Break problem into smaller pieces
208
+
209
+ Take a different approach now."""
210
+
211
+
212
+ class ValidateInput(BaseModel):
213
+ proposed_answer: str = Field(description="The answer you plan to submit")
214
+ original_question: str = Field(description="The original question")
215
+
216
+ @tool(args_schema=ValidateInput)
217
+ def validate_answer(proposed_answer: str, original_question: str) -> str:
218
+ """
219
+ Validates your proposed answer before submission.
220
+ CRITICAL: ALWAYS call this before final_answer_tool.
221
+
222
+ Checks:
223
+ - Does the answer match what was asked?
224
+ - Is it in the correct format?
225
+ - Are there any obvious issues?
226
+
227
+ If validation passes, then call final_answer_tool.
228
+ If validation fails, gather more information or correct the format.
229
+ """
230
+ print(f"✓ Validating answer: '{proposed_answer[:50]}...'")
231
+
232
+ issues = []
233
+ warnings = []
234
+
235
+ # Check for conversational fluff
236
+ fluff_phrases = ["the answer is", "based on", "according to", "i found that", "here is", "final answer"]
237
+ if any(phrase in proposed_answer.lower() for phrase in fluff_phrases):
238
+ issues.append("❌ Remove conversational text. Provide ONLY the answer.")
239
+
240
+ # Check for number format if question asks for numbers
241
+ number_keywords = ["how many", "what number", "count", "total", "sum"]
242
+ if any(kw in original_question.lower() for kw in number_keywords):
243
+ if not any(char.isdigit() for char in proposed_answer):
244
+ warnings.append("⚠️ Question seems to ask for a number, but answer contains no digits.")
245
+
246
+ # Check for list format
247
+ if "list" in original_question.lower() and "," not in proposed_answer:
248
+ warnings.append("⚠️ Question asks for a list, consider comma-separated format.")
249
+
250
+ # Check for yes/no questions
251
+ if original_question.lower().strip().startswith(("is ", "are ", "was ", "were ", "do ", "does ", "did ", "can ", "will ")):
252
+ if proposed_answer.lower() not in ["yes", "no", "true", "false"]:
253
+ warnings.append("⚠️ This looks like a yes/no question. Consider simple yes/no answer.")
254
+
255
+ # Check for code fences or markdown
256
+ if "```" in proposed_answer:
257
+ issues.append("❌ Remove code fences (```) from the answer.")
258
+
259
+ # Check length
260
+ if len(proposed_answer) > 500:
261
+ warnings.append("⚠️ Answer is quite long. Are you sure this is just the answer and not an explanation?")
262
+
263
+ if issues:
264
+ return "🚫 VALIDATION FAILED:\n" + "\n".join(issues) + "\n\nFix these issues before calling final_answer_tool."
265
+
266
+ if warnings:
267
+ return "⚠️ VALIDATION WARNINGS:\n" + "\n".join(warnings) + "\n\nConsider these points, but you may proceed if confident."
268
+
269
+ return "✅ VALIDATION PASSED: Answer looks good! Proceed with final_answer_tool now."
270
+
271
+
272
+ # =============================================================================
273
+ # CORE TOOLS
274
  # =============================================================================
275
 
276
  class SearchInput(BaseModel):
 
278
 
279
  @tool(args_schema=SearchInput)
280
  def search_tool(query: str) -> str:
281
+ """
282
+ Searches the web using DuckDuckGo.
283
+ Use for: recent information, facts, general web searches.
284
+
285
+ Tips:
286
+ - Keep queries concise and specific
287
+ - Include year for time-sensitive queries (e.g., "GDP Brazil 2016")
288
+ - Try different phrasings if first search doesn't help
289
+ """
290
  if not isinstance(query, str) or not query.strip():
291
  return "Error: Invalid input. 'query' must be a non-empty string."
292
 
293
+ print(f"🔍 Searching: {query}")
294
  try:
295
  search = DuckDuckGoSearchRun()
296
  result = search.run(query)
 
301
  return f"Error running search for '{query}': {str(e)}"
302
 
303
 
304
+ class CalcInput(BaseModel):
305
+ expression: str = Field(description="Mathematical expression to evaluate (e.g., '2 + 2', 'sqrt(16)', '45 * 1.2')")
306
+
307
+ @tool(args_schema=CalcInput)
308
+ def calculator(expression: str) -> str:
309
+ """
310
+ Evaluates mathematical expressions.
311
+ Use this for ANY calculations instead of code_interpreter.
312
+
313
+ Supports: +, -, *, /, **, sqrt, sin, cos, tan, log, exp, pi, e, abs, round
314
+
315
+ Examples:
316
+ - calculator("127 * 83")
317
+ - calculator("sqrt(144)")
318
+ - calculator("(45 + 23) / 2")
319
+ """
320
+ if not isinstance(expression, str) or not expression.strip():
321
+ return "Error: Invalid expression."
322
+
323
+ print(f"🧮 Calculating: {expression}")
324
+
325
+ try:
326
+ # Create safe namespace with math functions
327
+ import math
328
+ safe_dict = {
329
+ 'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan,
330
+ 'log': math.log, 'log10': math.log10, 'exp': math.exp,
331
+ 'pi': math.pi, 'e': math.e, 'abs': abs, 'round': round,
332
+ 'pow': pow, 'sum': sum, 'min': min, 'max': max
333
+ }
334
+
335
+ result = eval(expression, {"__builtins__": {}}, safe_dict)
336
+ return f"{result}"
337
+ except Exception as e:
338
+ return f"Error evaluating '{expression}': {str(e)}\nMake sure to use proper syntax (e.g., sqrt(16), not sqrt 16)"
339
+
340
+
341
  class CodeInput(BaseModel):
342
+ code: str = Field(description="Python code to execute. MUST include print() for output.")
343
 
344
  @tool(args_schema=CodeInput)
345
  def code_interpreter(code: str) -> str:
346
  """
347
+ Executes Python code for complex data processing.
348
+
349
+ WHEN TO USE:
350
+ - Data analysis (CSV, Excel files)
351
+ - Complex calculations with loops/conditionals
352
+ - String manipulation
353
+ - Date/time calculations
354
+
355
+ WHEN NOT TO USE:
356
+ - Simple math (use calculator instead)
357
+ - Web searches (use search_tool)
358
+
359
+ Available libraries: pandas as pd, numpy as np, json, re, datetime
360
+
361
+ CRITICAL: Always use print() to output results!
362
  """
363
  if not isinstance(code, str):
364
  return "Error: Invalid input. 'code' must be a string."
365
 
366
+ # Safety checks
367
+ dangerous_patterns = ['__import__', 'eval(', 'compile(', 'subprocess', 'os.system', 'exec(']
368
  code_lower = code.lower()
369
  for pattern in dangerous_patterns:
370
  if pattern in code_lower:
 
373
  if 'open(' in code_lower and any(mode in code for mode in ["'w'", '"w"', "'a'", '"a"', "'wb'", '"wb"']):
374
  return "Error: Writing files is not allowed in code_interpreter. Use write_file tool instead."
375
 
376
+ print(f"💻 Executing code...")
377
  output_stream = io.StringIO()
378
  error_stream = io.StringIO()
379
 
 
381
  with contextlib.redirect_stdout(output_stream), contextlib.redirect_stderr(error_stream):
382
  safe_globals = {
383
  "pd": pd,
384
+ "np": np,
385
+ "json": json,
386
+ "re": re,
387
  "__builtins__": __builtins__
388
  }
389
  exec(code, safe_globals, {})
 
397
  if stdout:
398
  if len(stdout) > MAX_MESSAGE_LENGTH:
399
  stdout = stdout[:MAX_MESSAGE_LENGTH] + f"\n...[truncated, {len(stdout)} total chars]"
400
+ return f"{stdout}"
401
 
402
+ return "Code executed but produced no output. Remember to use print() to display results!"
403
 
404
  except Exception as e:
405
  tb_str = traceback.format_exc()
 
407
 
408
 
409
  class ReadFileInput(BaseModel):
410
+ path: str = Field(description="Path to the file to read")
411
 
412
  @tool(args_schema=ReadFileInput)
413
  def read_file(path: str) -> str:
414
+ """Reads a file from the filesystem."""
415
  if not isinstance(path, str) or not path.strip():
416
  return "Error: Invalid input. 'path' must be a non-empty string."
417
 
418
+ print(f"📄 Reading file: {path}")
419
 
420
  file_path = find_file(path)
421
  if not file_path:
 
437
 
438
 
439
  class WriteFileInput(BaseModel):
440
+ path: str = Field(description="Path where file should be written")
441
+ content: str = Field(description="Content to write to the file")
442
 
443
  @tool(args_schema=WriteFileInput)
444
  def write_file(path: str, content: str) -> str:
445
+ """Writes content to a file."""
446
  if not isinstance(path, str) or not path.strip():
447
  return "Error: Invalid input. 'path' must be a non-empty string."
448
  if not isinstance(content, str):
449
  return "Error: Invalid input. 'content' must be a string."
450
 
451
+ print(f"✍️ Writing file: {path}")
452
 
453
  try:
454
  file_path = Path.cwd() / path
 
460
 
461
 
462
  class ListDirInput(BaseModel):
463
+ path: str = Field(description="Directory path to list", default=".")
464
 
465
  @tool(args_schema=ListDirInput)
466
  def list_directory(path: str = ".") -> str:
467
+ """Lists files and directories in a path."""
468
+ print(f"📁 Listing directory: {path}")
469
 
470
  try:
471
  dir_path = Path.cwd() / path if path != "." else Path.cwd()
 
499
 
500
 
501
  class AudioInput(BaseModel):
502
+ file_path: str = Field(description="Path to audio file to transcribe")
503
 
504
  @tool(args_schema=AudioInput)
505
  def audio_transcription_tool(file_path: str) -> str:
506
+ """Transcribes audio files to text using Whisper."""
507
  if not isinstance(file_path, str) or not file_path.strip():
508
  return "Error: Invalid input. 'file_path' must be a non-empty string."
509
 
510
+ print(f"🎤 Transcribing audio: {file_path}")
511
 
512
  if asr_pipeline is None:
513
  return "Error: ASR pipeline is not available."
 
527
  except Exception as e:
528
  return f"Error transcribing '{file_path}': {str(e)}"
529
 
 
530
  class YoutubeInput(BaseModel):
531
+ video_url: str = Field(description="YouTube video URL")
532
 
533
  @tool(args_schema=YoutubeInput)
534
  def get_youtube_transcript(video_url: str) -> str:
535
+ """Fetches transcript/captions from a YouTube video."""
536
  if not isinstance(video_url, str) or not video_url.strip():
537
  return "Error: Invalid input. 'video_url' must be a non-empty string."
538
 
539
+ print(f"📺 Getting YouTube transcript: {video_url}")
540
 
541
  try:
542
  video_id = None
 
560
 
561
 
562
  class ScrapeInput(BaseModel):
563
+ url: str = Field(description="URL to scrape (must start with http:// or https://)")
564
+ query: str = Field(description="Specific question or information to find on the page")
565
 
566
  @tool(args_schema=ScrapeInput)
567
  def scrape_and_retrieve(url: str, query: str) -> str:
568
  """
569
+ Scrapes a webpage and uses RAG to find relevant information.
570
+
571
+ Use when:
572
+ - You need specific information from a known webpage
573
+ - Search results give you a URL that contains the answer
574
+ - You need to extract data from a specific website
575
  """
576
  if not (url.lower().startswith(('http://', 'https://'))):
577
  return f"Error: Invalid URL. Must start with http:// or https://. Got: '{url}'"
578
  if not query or not query.strip():
579
  return "Error: A query is required to search the page content."
580
 
 
581
  if global_embeddings is None or global_text_splitter is None:
582
  if not initialize_rag_components():
583
  return "Error: RAG components could not be initialized."
584
 
585
+ print(f"🌐 Scraping & retrieving from: {url}")
586
 
587
  try:
 
588
  headers = {
589
  'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
590
  }
 
591
  response = requests.get(url, headers=headers, timeout=20)
592
  response.raise_for_status()
593
 
 
594
  soup = BeautifulSoup(response.text, 'html.parser')
595
 
 
596
  for tag in soup(["script", "style", "nav", "footer", "aside", "header", "iframe", "noscript"]):
597
  tag.extract()
598
 
 
599
  main_content = soup.find('main') or soup.find('article') or soup.find('div', class_=re.compile('content|main|article', re.I)) or soup.body
600
 
601
  if not main_content:
602
  return "Error: Could not find main content on the page."
603
 
 
604
  text = main_content.get_text(separator='\n', strip=True)
 
 
605
  lines = [line.strip() for line in text.splitlines()]
606
  text = '\n'.join(line for line in lines if line)
607
 
608
  if not text or len(text) < 50:
609
  return f"Error: Scraped content was too short or empty (length: {len(text)})."
610
 
 
 
 
611
  chunks = global_text_splitter.split_text(text)
612
 
613
  if not chunks:
614
  return "Error: Text could not be split into chunks."
615
 
 
 
 
616
  docs = [Document(page_content=chunk, metadata={"source": url}) for chunk in chunks]
617
 
 
 
618
  db = FAISS.from_documents(docs, global_embeddings)
619
 
 
 
620
  retriever = db.as_retriever(search_kwargs={"k": 5})
621
  retrieved_docs = retriever.invoke(query)
622
 
623
  if not retrieved_docs:
624
  return f"No relevant information found on {url} for query: '{query}'\n\nThe page was successfully scraped but doesn't seem to contain information matching your query."
625
 
 
 
 
626
  context_parts = []
627
  for i, doc in enumerate(retrieved_docs, 1):
628
  context_parts.append(f"[Chunk {i}]\n{doc.page_content}")
629
 
630
  context = "\n\n---\n\n".join(context_parts)
631
 
632
+ result = f"Relevant information from {url}:\n\n{context}"
633
 
634
  return truncate_if_needed(result)
635
 
 
641
 
642
 
643
  class FinalAnswerInput(BaseModel):
644
+ answer: str = Field(description="The final answer - EXACTLY what was asked for, nothing more")
645
 
646
  @tool(args_schema=FinalAnswerInput)
647
  def final_answer_tool(answer: str) -> str:
648
  """
649
+ Submit your final answer.
650
+
651
+ CRITICAL RULES:
652
+ 1. ALWAYS call validate_answer() before this
653
+ 2. The answer must be EXACTLY what was asked for
654
+ 3. NO conversational text (no "The answer is...", etc.)
655
+ 4. NO explanations
656
+ 5. Match the requested format exactly
657
+
658
+ Examples:
659
+ - If asked for a number: "42" (not "The answer is 42")
660
+ - If asked for a list: "red, blue, green" (not "The colors are: red, blue, green")
661
+ - If asked yes/no: "yes" (not "Yes, it is true")
662
  """
663
  if not isinstance(answer, str):
664
  try:
 
666
  except:
667
  return "Error: Invalid input. 'answer' must be a string."
668
 
669
+ print(f" FINAL ANSWER SUBMITTED: {answer}")
 
670
  return answer
671
 
672
 
673
+ # =============================================================================
674
+ # DEFINED TOOLS LIST
675
+ # =============================================================================
676
+ defined_tools = [
677
+ # Planning & Reflection (use these first!)
678
+ create_plan,
679
+ reflect_on_progress,
680
+ validate_answer,
681
+
682
+ # Core tools
683
+ search_tool,
684
+ calculator,
685
+ code_interpreter,
686
+
687
+ # File operations
688
+ read_file,
689
+ write_file,
690
+ list_directory,
691
+
692
+ # Specialized tools
693
+ audio_transcription_tool,
694
+ get_youtube_transcript,
695
+ scrape_and_retrieve,
696
+
697
+ # Final answer
698
+ final_answer_tool
699
+ ]
700
+
701
+
702
+ # =============================================================================
703
+ # AGENT STATE
704
+ # =============================================================================
705
+ class AgentState(TypedDict):
706
+ messages: Annotated[List[AnyMessage], add_messages]
707
+ turn: int
708
+ has_plan: bool
709
+ consecutive_errors: int
710
+ tool_history: List[str]
711
+
712
+
713
  # =============================================================================
714
  # FALLBACK PARSER
715
  # =============================================================================
716
  def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
717
+ """Parses malformed tool call strings from an LLM response."""
718
+ print(f"Fallback parsing LLM content (first 500 chars):\n{content[:500]}")
 
 
719
  tool_name = None
720
  tool_input = None
721
  cleaned_str = None
722
 
 
723
  func_match = re.search(
724
  r"<function[(=]\s*([^)]+)\s*[)>](.*)",
725
  content,
 
739
  cleaned_str = cleaned_str.strip().rstrip(',')
740
 
741
  tool_input = json.loads(cleaned_str)
742
+ print(f"🔧 Fallback: Parsed tool call for '{tool_name}'")
743
  else:
744
+ print(f"⚠️ Fallback: Found <function> but no JSON blob.")
745
  tool_name = None
746
 
747
  except json.JSONDecodeError as e:
748
+ print(f"⚠️ Fallback: json.loads failed, trying ast.literal_eval.")
749
  try:
750
  if cleaned_str:
751
  potential_input = ast.literal_eval(cleaned_str)
752
  if isinstance(potential_input, dict):
753
  tool_input = potential_input
754
+ print(f"🔧 Fallback: Parsed with ast.literal_eval for '{tool_name}'")
755
  else:
 
756
  tool_name = None
757
  else:
758
  tool_name = None
759
  except:
760
  tool_name = None
761
 
 
762
  if tool_name and tool_input is not None:
763
  if any(t.name == tool_name for t in tools):
764
  tool_call = ToolCall(
 
770
  return [tool_call]
771
  else:
772
  print(f"❌ Tool '{tool_name}' not found in available tools")
 
773
 
774
  print("❌ Failed to parse any valid tool call from content")
775
  return []
776
 
777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778
  # =============================================================================
779
  # CONDITIONAL EDGE FUNCTION
780
  # =============================================================================
781
  def should_continue(state: AgentState):
782
+ """Decide whether to continue, call tools, or end."""
 
 
783
  last_message = state['messages'][-1]
784
  current_turn = state.get('turn', 0)
785
 
786
+ # Check for final_answer_tool
787
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
788
  for tool_call in last_message.tool_calls:
789
  if tool_call.get("name") == "final_answer_tool":
790
  print("--- Condition: final_answer_tool called, ending. ---")
791
  return END
792
 
793
+ # Check turn limit
794
  if current_turn >= MAX_TURNS:
795
  print(f"--- Condition: Max turns ({MAX_TURNS}) reached. Ending. ---")
796
  return END
797
 
798
+ # Route to tools if tool calls exist
799
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
800
  print("--- Condition: Tools called, routing to tools node. ---")
801
  return "tools"
802
 
803
+ # Loop prevention
804
  if len(state['messages']) > 2 and isinstance(last_message, AIMessage) and isinstance(state['messages'][-2], AIMessage):
805
  print(f"--- Condition: Detected 2+ consecutive AI messages (Turn {current_turn}). Ending to prevent loop. ---")
806
  return END
807
 
808
+ # Loop back to agent
809
  print(f"--- Condition: No tool call (Turn {current_turn}). Continuing to agent. ---")
810
  return "agent"
811
 
812
 
813
  # =============================================================================
814
+ # ENHANCED AGENT CLASS WITH PLANNING & REFLECTION
815
  # =============================================================================
816
+ class PlanningReflectionAgent:
817
  def __init__(self):
818
+ print("🧠 PlanningReflectionAgent initializing...")
819
 
820
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
821
  if not GROQ_API_KEY:
 
825
 
826
  # Initialize RAG Components
827
  if not initialize_rag_components():
828
+ print("⚠️ Warning: RAG components failed to initialize.")
829
 
830
  # Build tool descriptions
831
  tool_desc_list = []
 
843
  tool_desc_list.append(desc)
844
  tool_descriptions = "\n".join(tool_desc_list)
845
 
846
+ # Enhanced System Prompt with Planning & Reflection
847
+ self.system_prompt = f"""You are an elite AI agent designed for the GAIA benchmark - the most challenging question-answering tasks.
848
+
849
+ 🎯 YOUR MISSION: Provide the EXACT answer in the EXACT format requested.
850
+
851
+ ═══════════════════════════════════════════════════════════════
852
+ 📋 MANDATORY PROTOCOL - FOLLOW THIS RELIGIOUSLY:
853
+ ═══════════════════════════════════════════════════════════════
854
+
855
+ **PHASE 1: PLANNING (For complex/multi-step questions)**
856
+ ├─ 1. Call create_plan() to think through your approach
857
+ ├─ 2. Identify what information you need
858
+ └─ 3. Determine the sequence of steps
859
+
860
+ **PHASE 2: EXECUTION (One step at a time)**
861
+ ├─ 1. Take ONE action per turn
862
+ ├─ 2. Use the RIGHT tool for each task:
863
+ │ • Simple math → calculator()
864
+ │ • Complex data → code_interpreter()
865
+ │ • Web info → search_tool()
866
+ │ • Specific page → scrape_and_retrieve()
867
+ │ • Files → read_file()
868
+ ├─ 3. After EACH tool, evaluate the result
869
+ └─ 4. Ask: "Do I have enough to answer now?"
870
+
871
+ **PHASE 3: REFLECTION (If stuck)**
872
+ ├─ If no progress after 3-5 turns → call reflect_on_progress()
873
+ ├─ If tools keep failing → try different approach
874
+ └─ If going in circles → step back and reconsider
875
+
876
+ **PHASE 4: VALIDATION & SUBMISSION**
877
+ ├─ 1. When you have the answer → call validate_answer()
878
+ ├─ 2. If validation passes → call final_answer_tool()
879
+ └─ 3. If validation fails → fix the issue first
880
+
881
+ ═══════════════════════════════════════════════════════════════
882
+ 🎓 EXAMPLES - LEARN FROM THESE:
883
+ ═══════════════════════════════════════════════════════════════
884
+
885
+ **Example 1: Simple Math**
886
+ Q: What is 127 × 83?
887
+ Turn 1: calculator("127 * 83") → 10541
888
+ Turn 2: validate_answer("10541", "What is 127 × 83?") → ✅ Pass
889
+ Turn 3: final_answer_tool("10541")
890
+
891
+ **Example 2: Multi-step Research**
892
+ Q: What was the population of Einstein's birthplace in 1900?
893
+ Turn 1: create_plan("What was the population of Einstein's birthplace in 1900?")
894
+ Turn 2: search_tool("Albert Einstein birthplace") → Ulm, Germany
895
+ Turn 3: search_tool("Ulm Germany population 1900") → approximately 50,000
896
+ Turn 4: validate_answer("50000", "What was the population...") → ✅ Pass
897
+ Turn 5: final_answer_tool("50000")
898
+
899
+ **Example 3: File + Calculation**
900
+ Q: What's the average of the 'score' column in data.csv?
901
+ Turn 1: list_directory(".") → [files shown]
902
+ Turn 2: read_file("data.csv") → [content]
903
+ Turn 3: code_interpreter("import pandas as pd; df = pd.read_csv('data.csv'); print(df['score'].mean())")
904
+ → 78.5
905
+ Turn 4: validate_answer("78.5", "What's the average...") → ✅ Pass
906
+ Turn 5: final_answer_tool("78.5")
907
+
908
+ **Example 4: Getting Unstuck**
909
+ Q: What's the GDP of the 2016 Olympics host?
910
+ Turn 1: search_tool("2016 Olympics") → [general info, no clear answer]
911
+ Turn 2: search_tool("Olympics 2016 location") → [still unclear]
912
+ Turn 3: reflect_on_progress("Tried searching but not getting clear host country")
913
+ → Try: "2016 Summer Olympics host country"
914
+ Turn 4: search_tool("2016 Summer Olympics host country") → Brazil
915
+ Turn 5: search_tool("Brazil GDP 2016") → $1.796 trillion
916
+ Turn 6: validate_answer("1.796 trillion", original_q) → ✅ Pass
917
+ Turn 7: final_answer_tool("1.796 trillion")
918
+
919
+ ═══════════════════════════════════════════════════════════════
920
+ ⚠️ CRITICAL RULES - NEVER VIOLATE THESE:
921
+ ═══════════════════════════════════════════════════════════════
922
+
923
+ 1. **NO GUESSING**: Always use tools. Never use your own knowledge.
924
+ 2. **ONE STEP AT A TIME**: Don't try to do multiple things in one turn.
925
+ 3. **EXACT FORMAT**: Answer must be EXACTLY what was asked for.
926
+ 4. **NO FLUFF**: Never add "The answer is" or explanations in final answer.
927
+ 5. **ALWAYS VALIDATE**: Call validate_answer() before final_answer_tool().
928
+ 6. **PLAN COMPLEX TASKS**: Multi-step questions need create_plan() first.
929
+ 7. **REFLECT WHEN STUCK**: If no progress after 5 turns, call reflect_on_progress().
930
+
931
+ ═══════════════════════════════════════════════════════════════
932
+ 📚 AVAILABLE TOOLS:
933
+ ═══════════════════════════════════════════════════════════════
934
 
 
 
 
 
 
 
935
  {tool_descriptions}
936
 
937
+ ═══════════════════════════════════════════════════════════════
938
+ 🎯 REMEMBER: Quality over speed. Think carefully, plan ahead, execute methodically.
939
+ ═══════════════════════════════════════════════════════════════
940
  """
941
 
942
  print("Initializing Groq LLM...")
943
  try:
 
944
  self.llm_with_tools = ChatGroq(
945
  temperature=0,
946
  groq_api_key=GROQ_API_KEY,
 
948
  max_tokens=4096,
949
  timeout=60
950
  ).bind_tools(self.tools, tool_choice="auto")
951
+ print("✅ LLM initialized.")
952
 
953
  except Exception as e:
954
  print(f"❌ Error initializing Groq: {e}")
955
  raise
956
 
957
+ # Agent Node with Enhanced Logic
958
  def agent_node(state: AgentState):
959
  current_turn = state.get('turn', 0) + 1
960
+ print(f"\n{'='*70}")
961
+ print(f"🤖 AGENT TURN {current_turn}/{MAX_TURNS}")
962
+ print('='*70)
963
 
964
  if current_turn > MAX_TURNS:
965
+ return {
966
+ "messages": [SystemMessage(content="Max turns reached. Submitting best available answer.")],
967
+ "turn": current_turn
968
+ }
969
+
970
+ # Check if we should auto-trigger reflection
971
+ should_reflect = False
972
+ consecutive_errors = state.get('consecutive_errors', 0)
973
+
974
+ if current_turn > 5 and current_turn % REFLECT_EVERY_N_TURNS == 0:
975
+ should_reflect = True
976
+ print("🤔 Auto-triggering reflection (periodic check)")
977
+
978
+ if consecutive_errors >= 3:
979
+ should_reflect = True
980
+ print("🤔 Auto-triggering reflection (multiple errors)")
981
+
982
+ # Add reflection hint if needed
983
+ messages_to_send = state["messages"].copy()
984
+ if should_reflect and not state.get('has_plan', False):
985
+ hint = SystemMessage(
986
+ content="⚠️ SYSTEM HINT: You've been working for several turns. Consider calling reflect_on_progress() to evaluate your approach."
987
+ )
988
+ messages_to_send.append(hint)
989
+
990
+ # Invoke LLM
991
  max_retries = 3
992
  ai_message = None
993
  for attempt in range(max_retries):
994
  try:
995
+ ai_message = self.llm_with_tools.invoke(messages_to_send)
996
  break
997
  except Exception as e:
998
  print(f"⚠️ LLM attempt {attempt+1}/{max_retries} failed: {e}")
 
1002
  )
1003
  time.sleep(2 ** attempt)
1004
 
1005
+ # Fallback Parsing
1006
  if not ai_message.tool_calls and isinstance(ai_message.content, str) and ai_message.content.strip():
1007
  parsed_tool_calls = parse_tool_call_from_string(ai_message.content, self.tools)
1008
  if parsed_tool_calls:
1009
+ print("🔧 Fallback: Successfully rebuilt tool call")
1010
  ai_message.tool_calls = parsed_tool_calls
1011
  ai_message.content = ""
1012
+
1013
+ # Track tool usage
1014
+ tool_history = state.get('tool_history', [])
1015
+ has_plan = state.get('has_plan', False)
1016
 
1017
  if ai_message.tool_calls:
1018
+ tool_name = ai_message.tool_calls[0]['name']
1019
+ print(f"🔧 Tool Call: {tool_name}")
1020
+ tool_history.append(tool_name)
1021
+
1022
+ if tool_name == "create_plan":
1023
+ has_plan = True
1024
  else:
1025
+ print(f"💭 Reasoning: {ai_message.content[:200]}...")
1026
 
1027
+ return {
1028
+ "messages": [ai_message],
1029
+ "turn": current_turn,
1030
+ "has_plan": has_plan,
1031
+ "tool_history": tool_history
1032
+ }
1033
 
1034
+ # Tool Node with Error Tracking
1035
+ def tool_node_wrapper(state: AgentState):
1036
+ """Wraps tool execution to track errors"""
1037
+ tool_node = ToolNode(self.tools)
1038
+ result = tool_node(state)
1039
+
1040
+ # Check if last message is a tool error
1041
+ if result['messages']:
1042
+ last_msg = result['messages'][-1]
1043
+ if isinstance(last_msg, ToolMessage) and "Error" in last_msg.content:
1044
+ consecutive_errors = state.get('consecutive_errors', 0) + 1
1045
+ result['consecutive_errors'] = consecutive_errors
1046
+ else:
1047
+ result['consecutive_errors'] = 0
1048
+
1049
+ return result
1050
+
1051
  # Build Graph
1052
+ print("Building Planning & Reflection Agent graph...")
1053
  graph_builder = StateGraph(AgentState)
1054
 
1055
  graph_builder.add_node("agent", agent_node)
1056
+ graph_builder.add_node("tools", tool_node_wrapper)
1057
 
1058
  graph_builder.add_edge(START, "agent")
1059
 
 
1070
  graph_builder.add_edge("tools", "agent")
1071
 
1072
  self.graph = graph_builder.compile()
1073
+ print("✅ Planning & Reflection Agent graph compiled successfully.")
1074
 
1075
  def __call__(self, question: str) -> str:
1076
+ print(f"\n{'='*70}")
1077
+ print(f"🎯 NEW QUESTION")
1078
+ print(f"{'='*70}")
1079
+ print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
1080
+ print(f"{'='*70}\n")
1081
 
1082
  graph_input = {
1083
  "messages": [
1084
  SystemMessage(content=self.system_prompt),
1085
  HumanMessage(content=question)
1086
  ],
1087
+ "turn": 0,
1088
+ "has_plan": False,
1089
+ "consecutive_errors": 0,
1090
+ "tool_history": []
1091
  }
1092
 
1093
  final_answer = "AGENT FAILED TO PRODUCE ANSWER"
1094
  try:
1095
+ config = {"recursion_limit": MAX_TURNS + 10}
1096
  for event in self.graph.stream(graph_input, stream_mode="values", config=config):
1097
 
1098
+ if not event.get('messages'):
1099
+ continue
1100
+
1101
+ last_message = event["messages"][-1]
1102
 
1103
  # Check for final answer extraction
1104
  if isinstance(last_message, AIMessage) and last_message.tool_calls:
1105
  if last_message.tool_calls[0].get("name") == "final_answer_tool":
1106
  final_answer_args = last_message.tool_calls[0].get('args', {})
1107
  if 'answer' in final_answer_args:
1108
+ final_answer = final_answer_args['answer']
1109
+ print(f"\n{'='*70}")
1110
+ print(f"✅ FINAL ANSWER CAPTURED: '{final_answer}'")
1111
+ print(f"{'='*70}\n")
1112
+ break
1113
  else:
1114
+ print(f"⚠️ final_answer_tool called without 'answer' argument")
1115
+ final_answer = "ERROR: FINAL_ANSWER_TOOL CALLED WITHOUT ANSWER"
1116
+ break
1117
 
1118
  elif isinstance(last_message, ToolMessage):
1119
+ result_preview = last_message.content[:300].replace('\n', ' ')
1120
+ print(f"📊 Tool Result: {result_preview}...")
1121
  elif isinstance(last_message, AIMessage) and not last_message.tool_calls:
1122
+ print(f"💭 AI Reasoning: {last_message.content[:300]}...")
 
 
 
1123
 
1124
+ # Final Answer Cleaning
1125
  cleaned_answer = str(final_answer).strip()
1126
+
1127
+ # Remove common prefixes
1128
+ prefixes_to_remove = [
1129
+ "The answer is:", "Here is the answer:", "Based on the information:",
1130
+ "Final Answer:", "Answer:", "The final answer is:", "My answer is:",
1131
+ "According to", "I found that", "The result is:"
1132
+ ]
1133
  for prefix in prefixes_to_remove:
1134
  if cleaned_answer.lower().startswith(prefix.lower()):
1135
  potential_answer = cleaned_answer[len(prefix):].strip()
1136
+ if potential_answer:
1137
  cleaned_answer = potential_answer
1138
+ break
1139
 
1140
+ # Remove code fences
1141
  cleaned_answer = remove_fences_simple(cleaned_answer)
1142
+
1143
+ # Remove surrounding backticks
1144
+ while cleaned_answer.startswith("`") and cleaned_answer.endswith("`"):
1145
+ cleaned_answer = cleaned_answer[1:-1].strip()
1146
+
1147
+ # Remove quotes if they wrap the entire answer
1148
+ if (cleaned_answer.startswith('"') and cleaned_answer.endswith('"')) or \
1149
+ (cleaned_answer.startswith("'") and cleaned_answer.endswith("'")):
1150
+ cleaned_answer = cleaned_answer[1:-1].strip()
1151
+
1152
+ # Remove trailing periods for non-sentence answers
1153
+ if cleaned_answer.endswith('.') and len(cleaned_answer.split()) < 10:
1154
+ cleaned_answer = cleaned_answer[:-1]
1155
+
1156
+ print(f"\n{'='*70}")
1157
+ print(f"🎉 FINAL CLEANED ANSWER")
1158
+ print(f"{'='*70}")
1159
+ print(f"{cleaned_answer}")
1160
+ print(f"{'='*70}\n")
1161
+
1162
  return cleaned_answer
1163
 
1164
  except Exception as e:
1165
+ print(f"Error running agent graph: {e}")
1166
  tb_str = traceback.format_exc()
1167
  print(tb_str)
1168
  return f"AGENT GRAPH ERROR: {e}"
1169
 
1170
 
1171
+ # =============================================================================
1172
+ # GLOBAL AGENT INSTANTIATION
1173
+ # =============================================================================
1174
  try:
1175
+ initialize_rag_components()
1176
+
1177
+ agent = PlanningReflectionAgent()
1178
+ print("✅ Global PlanningReflectionAgent instantiated successfully.")
1179
+ if asr_pipeline is None:
1180
+ print("⚠️ Global ASR Pipeline failed to load.")
1181
  except Exception as e:
1182
  print(f"❌ FATAL: Could not instantiate global agent: {e}")
1183
  traceback.print_exc()
1184
  agent = None
1185
+
1186
 
1187
  # ====================================================
1188
  # --- (Original Template Code - Mock Questions Version) ---