gabejavitt commited on
Commit
b096a45
Β·
verified Β·
1 Parent(s): 3e572ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -89
app.py CHANGED
@@ -602,51 +602,51 @@ def validate_answer(proposed_answer: str, original_question: str = "") -> str:
602
  start_time = time.time()
603
 
604
  try:
605
- print(f"βœ“ Validating: '{answer[:50]}...'")
606
-
607
  warnings = []
608
  errors = []
609
  normalization_needed = []
610
-
611
  # Normalize for validation
612
- normalized = normalize_answer(answer)
613
-
614
- if normalized != answer:
615
  normalization_needed.append(f"Consider using normalized form: '{normalized}'")
616
-
617
  # Check 1: Empty answer
618
- if not answer or not answer.strip():
619
  errors.append("Answer is empty")
620
-
621
  # Check 2: Too long (probably explaining instead of answering)
622
- if len(answer) > 200:
623
  warnings.append("Answer is very long (>200 chars). Consider if question asks for brief response.")
624
-
625
  # Check 3: Contains question words
626
  question_words = ['what', 'who', 'when', 'where', 'why', 'how', 'which']
627
- if any(word in answer.lower() for word in question_words):
628
  warnings.append("Answer contains question words. Make sure you're providing the answer, not rephrasing the question.")
629
-
630
  # Check 4: List ordering
631
- if "," in answer:
632
- items = [item.strip() for item in answer.split(",")]
633
  if len(items) > 1:
634
  warnings.append(f"List detected with {len(items)} items. Verify order matches question requirements.")
635
-
636
  # Check 5: Capitalization consistency
637
- if answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']:
638
- if not answer[0].isupper():
639
- normalization_needed.append(f"Consider capitalizing: '{answer.capitalize()}'")
640
-
641
  # Check 6: Abbreviations
642
- if any(abbrev in answer.lower() for abbrev in ['st.', 'dr.', 'mt.']):
643
- if "without abbreviations" in str(answer).lower() or "full" in str(answer).lower():
644
  warnings.append("Question may ask for full form without abbreviations")
645
-
646
  # Check 7: Spacing in lists
647
- if "," in answer:
648
  # Check for inconsistent spacing
649
- if ", " in answer and "," in answer.replace(", ", ""):
650
  normalization_needed.append("Inconsistent spacing in list. Use consistent ', ' format")
651
 
652
  # Build result
@@ -1115,7 +1115,7 @@ Generated Analysis Code:
1115
  **IMPORTANT**: The code above needs column names adjusted.
1116
  Use code_interpreter() with the corrected code to get the answer.
1117
 
1118
- Columns available: {", ".join(pd.read_csv(data_file) if file_ext == '.csv' else pd.read_excel(data_file)).columns.tolist()}
1119
  """
1120
 
1121
  telemetry.record_call("analyze_data_file", time.time() - start_time, True)
@@ -1271,30 +1271,30 @@ class ChessAnalysisInput(BaseModel):
1271
  description: str = Field(description="Context about position", default="")
1272
 
1273
  @tool(args_schema=ChessAnalysisInput)
1274
- def analyze_chess_position(file_path: str) -> str:
1275
  """
1276
  Analyze chess position from image using Gemini Vision + Stockfish.
1277
  Extracts FEN, analyzes best move.
1278
  """
1279
  start_time = time.time()
1280
-
1281
  try:
1282
- print(f"β™ŸοΈ Analyzing chess: {file_path}")
1283
-
1284
  # Find file
1285
- image_path = find_file(file_path)
1286
- if not image_path and os.path.exists(file_path):
1287
- image_path = Path(file_path)
1288
-
1289
- if not image_path or not image_path.exists():
1290
- raise FileNotFoundError(f"Image not found: {file_path}")
1291
 
1292
  GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
1293
  if not GOOGLE_API_KEY:
1294
  raise ValueError("GEMINI_API_KEY not set")
1295
 
1296
  # Read image as base64
1297
- with open(image_path, "rb") as f:
1298
  image_data = base64.b64encode(f.read()).decode("utf-8")
1299
 
1300
  # Use Gemini to extract FEN
@@ -1431,7 +1431,7 @@ def analyze_image(file_path: str, query: str) -> str:
1431
  message = HumanMessage(
1432
  content=[
1433
  {"type": "text", "text": query},
1434
- {"type": "image_url", "image_url": f"data:image/jpeg;base64,{img_base64}"}
1435
  ]
1436
  )
1437
 
@@ -1885,38 +1885,33 @@ def analyze_video(file_path: str, query: str) -> str:
1885
  GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
1886
  if not GOOGLE_API_KEY:
1887
  raise ValueError("GEMINI_API_KEY not set")
1888
-
1889
- # Read video as base64
1890
- print(f" Reading video file...")
1891
- with open(video_path, "rb") as f:
1892
- video_data = base64.b64encode(f.read()).decode("utf-8")
1893
-
1894
- # Use Gemini via LangChain
 
 
 
 
 
 
 
 
 
 
1895
  print(f" Analyzing with Gemini...")
1896
- llm = ChatGoogleGenerativeAI(
1897
- model="gemini-2.5-flash",
1898
- google_api_key=GOOGLE_API_KEY,
1899
- temperature=0
1900
- )
1901
-
1902
- # Create message with video
1903
- message = HumanMessage(
1904
- content=[
1905
- {
1906
- "type": "text",
1907
- "text": query
1908
- },
1909
- {
1910
- "type": "video_url",
1911
- "video_url": {
1912
- "url": f"data:video/mp4;base64,{video_data}"
1913
- }
1914
- }
1915
- ]
1916
- )
1917
-
1918
- response = llm.invoke([message])
1919
- result = response.content
1920
 
1921
  print(f"βœ“ Analysis complete: {len(result)} chars")
1922
 
@@ -2364,28 +2359,40 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
2364
  Keeps system message + recent history to stay under token limits.
2365
  """
2366
  messages = state.get("messages", [])
2367
-
2368
  # Keep first message (system prompt) + last N messages
2369
- MAX_MESSAGES = 20 # Adjust based on your needs
2370
-
 
 
 
2371
  if len(messages) > MAX_MESSAGES:
2372
  print(f"⚠️ Context pruning: {len(messages)} messages β†’ {MAX_MESSAGES}")
2373
-
2374
- # Always keep system message (if it exists)
2375
  system_msg = None
2376
  if messages and isinstance(messages[0], SystemMessage):
2377
  system_msg = messages[0]
2378
  messages = messages[1:]
2379
-
2380
- # Keep only recent messages
2381
  recent_messages = messages[-(MAX_MESSAGES-1):]
2382
-
2383
- # Reconstruct
2384
  if system_msg:
2385
- state["messages"] = [system_msg] + recent_messages
2386
  else:
2387
- state["messages"] = recent_messages
2388
-
 
 
 
 
 
 
 
 
 
 
 
 
2389
  return state
2390
 
2391
  # Build agent graph
@@ -2468,23 +2475,60 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
2468
  print(f"⚠️ Groq error (attempt {attempt+1}): {error_str[:200]}")
2469
 
2470
  # ===== IMPROVED RATE LIMIT HANDLING =====
2471
- # Check for rate limit FIRST
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2472
  if "429" in error_str or "rate limit" in error_str.lower():
2473
  print("❌ Groq rate limit hit!")
2474
-
2475
  if attempt < max_retries - 1:
2476
  wait = 10 * (2 ** attempt) # 10s, 20s, 40s
2477
  print(f" Waiting {wait}s before retry...")
2478
  time.sleep(wait)
2479
  continue
2480
-
2481
- # FINAL FALLBACK: Force search_tool as safe default
2482
- print("πŸ”„ Final attempt failed - using search_tool fallback")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2483
  ai_message = AIMessage(
2484
  content="",
2485
  tool_calls=[ToolCall(
2486
  name="search_tool",
2487
- args={"query": "answer to question"},
2488
  id=str(uuid.uuid4())
2489
  )]
2490
  )
@@ -2649,7 +2693,7 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
2649
  all_messages = []
2650
 
2651
  try:
2652
- config_dict = {"recursion_limit": config.MAX_TURNS + 10}
2653
 
2654
  for event in self.graph.stream(graph_input, stream_mode="values", config=config_dict):
2655
  if not event.get('messages'):
@@ -2664,7 +2708,7 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
2664
  if tool_call.get("name") == "final_answer_tool":
2665
  args = tool_call.get('args', {})
2666
  if 'answer' in args:
2667
- final_answer = args['answer']
2668
  print(f"\nβœ… FINAL: '{final_answer}'\n")
2669
  break
2670
 
 
602
  start_time = time.time()
603
 
604
  try:
605
+ print(f"βœ“ Validating: '{proposed_answer[:50]}...'")
606
+
607
  warnings = []
608
  errors = []
609
  normalization_needed = []
610
+
611
  # Normalize for validation
612
+ normalized = normalize_answer(proposed_answer)
613
+
614
+ if normalized != proposed_answer:
615
  normalization_needed.append(f"Consider using normalized form: '{normalized}'")
616
+
617
  # Check 1: Empty answer
618
+ if not proposed_answer or not proposed_answer.strip():
619
  errors.append("Answer is empty")
620
+
621
  # Check 2: Too long (probably explaining instead of answering)
622
+ if len(proposed_answer) > 200:
623
  warnings.append("Answer is very long (>200 chars). Consider if question asks for brief response.")
624
+
625
  # Check 3: Contains question words
626
  question_words = ['what', 'who', 'when', 'where', 'why', 'how', 'which']
627
+ if any(word in proposed_answer.lower() for word in question_words):
628
  warnings.append("Answer contains question words. Make sure you're providing the answer, not rephrasing the question.")
629
+
630
  # Check 4: List ordering
631
+ if "," in proposed_answer:
632
+ items = [item.strip() for item in proposed_answer.split(",")]
633
  if len(items) > 1:
634
  warnings.append(f"List detected with {len(items)} items. Verify order matches question requirements.")
635
+
636
  # Check 5: Capitalization consistency
637
+ if proposed_answer.lower() in ['right', 'left', 'yes', 'no', 'true', 'false']:
638
+ if not proposed_answer[0].isupper():
639
+ normalization_needed.append(f"Consider capitalizing: '{proposed_answer.capitalize()}'")
640
+
641
  # Check 6: Abbreviations
642
+ if any(abbrev in proposed_answer.lower() for abbrev in ['st.', 'dr.', 'mt.']):
643
+ if "without abbreviations" in str(proposed_answer).lower() or "full" in str(proposed_answer).lower():
644
  warnings.append("Question may ask for full form without abbreviations")
645
+
646
  # Check 7: Spacing in lists
647
+ if "," in proposed_answer:
648
  # Check for inconsistent spacing
649
+ if ", " in proposed_answer and "," in proposed_answer.replace(", ", ""):
650
  normalization_needed.append("Inconsistent spacing in list. Use consistent ', ' format")
651
 
652
  # Build result
 
1115
  **IMPORTANT**: The code above needs column names adjusted.
1116
  Use code_interpreter() with the corrected code to get the answer.
1117
 
1118
+ Columns available: {", ".join((pd.read_csv(data_file) if file_ext == '.csv' else pd.read_excel(data_file)).columns.tolist())}
1119
  """
1120
 
1121
  telemetry.record_call("analyze_data_file", time.time() - start_time, True)
 
1271
  description: str = Field(description="Context about position", default="")
1272
 
1273
  @tool(args_schema=ChessAnalysisInput)
1274
+ def analyze_chess_position(image_path: str, description: str = "") -> str:
1275
  """
1276
  Analyze chess position from image using Gemini Vision + Stockfish.
1277
  Extracts FEN, analyzes best move.
1278
  """
1279
  start_time = time.time()
1280
+
1281
  try:
1282
+ print(f"β™ŸοΈ Analyzing chess: {image_path}")
1283
+
1284
  # Find file
1285
+ image_path_obj = find_file(image_path)
1286
+ if not image_path_obj and os.path.exists(image_path):
1287
+ image_path_obj = Path(image_path)
1288
+
1289
+ if not image_path_obj or not image_path_obj.exists():
1290
+ raise FileNotFoundError(f"Image not found: {image_path}")
1291
 
1292
  GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
1293
  if not GOOGLE_API_KEY:
1294
  raise ValueError("GEMINI_API_KEY not set")
1295
 
1296
  # Read image as base64
1297
+ with open(image_path_obj, "rb") as f:
1298
  image_data = base64.b64encode(f.read()).decode("utf-8")
1299
 
1300
  # Use Gemini to extract FEN
 
1431
  message = HumanMessage(
1432
  content=[
1433
  {"type": "text", "text": query},
1434
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}}
1435
  ]
1436
  )
1437
 
 
1885
  GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
1886
  if not GOOGLE_API_KEY:
1887
  raise ValueError("GEMINI_API_KEY not set")
1888
+
1889
+ # Use Google GenAI SDK directly β€” LangChain wrapper doesn't support video_url
1890
+ import google.generativeai as genai
1891
+ genai.configure(api_key=GOOGLE_API_KEY)
1892
+
1893
+ print(f" Uploading video to Gemini Files API...")
1894
+ video_file = genai.upload_file(str(video_path))
1895
+
1896
+ # Poll until processing is complete
1897
+ import time as _time
1898
+ while video_file.state.name == "PROCESSING":
1899
+ _time.sleep(2)
1900
+ video_file = genai.get_file(video_file.name)
1901
+
1902
+ if video_file.state.name == "FAILED":
1903
+ raise RuntimeError(f"Gemini file processing failed: {video_file.state}")
1904
+
1905
  print(f" Analyzing with Gemini...")
1906
+ model = genai.GenerativeModel("gemini-2.5-flash")
1907
+ response = model.generate_content([query, video_file])
1908
+ result = response.text
1909
+
1910
+ # Clean up uploaded file
1911
+ try:
1912
+ genai.delete_file(video_file.name)
1913
+ except Exception:
1914
+ pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1915
 
1916
  print(f"βœ“ Analysis complete: {len(result)} chars")
1917
 
 
2359
  Keeps system message + recent history to stay under token limits.
2360
  """
2361
  messages = state.get("messages", [])
2362
+
2363
  # Keep first message (system prompt) + last N messages
2364
+ MAX_MESSAGES = 20
2365
+ # ~6000 token limit on Groq; system msg ~3000 chars leaves ~18000 for the rest
2366
+ MAX_TOOL_CONTENT = 1500
2367
+
2368
+ # Prune by count
2369
  if len(messages) > MAX_MESSAGES:
2370
  print(f"⚠️ Context pruning: {len(messages)} messages β†’ {MAX_MESSAGES}")
2371
+
 
2372
  system_msg = None
2373
  if messages and isinstance(messages[0], SystemMessage):
2374
  system_msg = messages[0]
2375
  messages = messages[1:]
2376
+
 
2377
  recent_messages = messages[-(MAX_MESSAGES-1):]
2378
+
 
2379
  if system_msg:
2380
+ messages = [system_msg] + recent_messages
2381
  else:
2382
+ messages = recent_messages
2383
+
2384
+ # Truncate oversized tool outputs to prevent 413 errors
2385
+ pruned = []
2386
+ for msg in messages:
2387
+ if isinstance(msg, ToolMessage) and len(msg.content) > MAX_TOOL_CONTENT:
2388
+ msg = ToolMessage(
2389
+ content=msg.content[:MAX_TOOL_CONTENT] + "...[truncated]",
2390
+ tool_call_id=msg.tool_call_id,
2391
+ name=msg.name
2392
+ )
2393
+ pruned.append(msg)
2394
+
2395
+ state["messages"] = pruned
2396
  return state
2397
 
2398
  # Build agent graph
 
2475
  print(f"⚠️ Groq error (attempt {attempt+1}): {error_str[:200]}")
2476
 
2477
  # ===== IMPROVED RATE LIMIT HANDLING =====
2478
+ # Context too large β€” truncate aggressively and retry immediately
2479
+ if "413" in error_str or "request too large" in error_str.lower():
2480
+ print("❌ Request too large (413) - aggressively pruning context")
2481
+ # Keep system message + last 4 messages, truncate tool content to 1000 chars
2482
+ pruned = []
2483
+ for msg in messages_to_send:
2484
+ if isinstance(msg, SystemMessage):
2485
+ pruned.append(msg)
2486
+ break
2487
+ pruned += messages_to_send[-4:]
2488
+ for msg in pruned:
2489
+ if isinstance(msg, ToolMessage) and len(msg.content) > 1000:
2490
+ msg = ToolMessage(
2491
+ content=msg.content[:1000] + "...[truncated]",
2492
+ tool_call_id=msg.tool_call_id,
2493
+ name=msg.name
2494
+ )
2495
+ messages_to_send = pruned
2496
+ print(f" Pruned to {len(messages_to_send)} messages, retrying...")
2497
+ continue
2498
+
2499
+ # Check for rate limit
2500
  if "429" in error_str or "rate limit" in error_str.lower():
2501
  print("❌ Groq rate limit hit!")
2502
+
2503
  if attempt < max_retries - 1:
2504
  wait = 10 * (2 ** attempt) # 10s, 20s, 40s
2505
  print(f" Waiting {wait}s before retry...")
2506
  time.sleep(wait)
2507
  continue
2508
+
2509
+ # FINAL FALLBACK: If Claude is available use it, otherwise fail fast
2510
+ if self.claude_llm:
2511
+ print("πŸ”„ Groq rate limit - switching to Claude fallback")
2512
+ self.llm_with_tools = self.claude_llm
2513
+ self.current_llm = "claude"
2514
+ try:
2515
+ ai_message = self.claude_llm.invoke(messages_to_send)
2516
+ break
2517
+ except Exception as claude_err:
2518
+ print(f"❌ Claude fallback also failed: {claude_err}")
2519
+
2520
+ # No LLM available β€” extract question and do one targeted search
2521
+ print("πŸ”„ No LLM available - attempting targeted search fallback")
2522
+ question_text = ""
2523
+ for msg in state["messages"]:
2524
+ if isinstance(msg, HumanMessage) and msg.content:
2525
+ question_text = str(msg.content)[:200].strip()
2526
+ break
2527
  ai_message = AIMessage(
2528
  content="",
2529
  tool_calls=[ToolCall(
2530
  name="search_tool",
2531
+ args={"query": question_text or "unknown question"},
2532
  id=str(uuid.uuid4())
2533
  )]
2534
  )
 
2693
  all_messages = []
2694
 
2695
  try:
2696
+ config_dict = {"recursion_limit": config.MAX_TURNS * 2 + 10}
2697
 
2698
  for event in self.graph.stream(graph_input, stream_mode="values", config=config_dict):
2699
  if not event.get('messages'):
 
2708
  if tool_call.get("name") == "final_answer_tool":
2709
  args = tool_call.get('args', {})
2710
  if 'answer' in args:
2711
+ final_answer = normalize_answer(args['answer'])
2712
  print(f"\nβœ… FINAL: '{final_answer}'\n")
2713
  break
2714