gabejavitt committed on
Commit
cbea565
Β·
verified Β·
1 Parent(s): 4961f0a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -39
app.py CHANGED
@@ -21,6 +21,8 @@ from transformers import pipeline
21
  from youtube_transcript_api import YouTubeTranscriptApi
22
  from bs4 import BeautifulSoup
23
  import requests
 
 
24
 
25
  # LangChain & LangGraph
26
  from langgraph.graph.message import add_messages
@@ -32,7 +34,6 @@ from langchain_groq import ChatGroq
32
  from langchain_google_genai import ChatGoogleGenerativeAI
33
  from langchain_community.llms import HuggingFaceHub
34
 
35
-
36
  # RAG
37
  from langchain_text_splitters import RecursiveCharacterTextSplitter
38
  from langchain_community.vectorstores import FAISS
@@ -470,6 +471,85 @@ def audio_transcription_tool(file_path: str) -> str:
470
  return f"Transcription error: {str(e)}"
471
 
472
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
class YoutubeInput(BaseModel):
    """Argument schema for the YouTube-transcript tool."""
    # Full YouTube video URL; the video ID is extracted from it downstream.
    video_url: str = Field(description="YouTube URL")
475
 
@@ -491,8 +571,6 @@ def get_youtube_transcript(video_url: str) -> str:
491
  if not video_id:
492
  return f"Error: Could not extract video ID."
493
 
494
- from youtube_transcript_api import YouTubeTranscriptApi
495
-
496
  # FIXED: Use get_transcript instead of list_transcripts
497
  transcript_list = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
498
 
@@ -614,6 +692,7 @@ defined_tools = [
614
 
615
  # Specialized
616
  audio_transcription_tool,
 
617
  get_youtube_transcript,
618
  scrape_and_retrieve,
619
 
@@ -722,7 +801,7 @@ def parse_tool_call_from_string(content: str, tools: List) -> List[ToolCall]:
722
 
723
 
724
  # =============================================================================
725
- # CONDITIONAL EDGE FUNCTION (FIXED)
726
  # =============================================================================
727
  def should_continue(state: AgentState):
728
  """Decide next step with robust logic."""
@@ -770,7 +849,6 @@ def should_continue(state: AgentState):
770
 
771
  # 5. Default: continue to agent
772
  print(f"πŸ”„ Default β†’ continuing to agent")
773
- return "agent"
774
 
775
 
776
  # =============================================================================
@@ -1150,14 +1228,35 @@ REMEMBER: One tool per turn. No reasoning without tools. Exact answer format.
1150
  print(f"🎯 NEW QUESTION")
1151
  print(f"{'='*70}")
1152
  print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
 
 
1153
  print(f"{'='*70}\n")
1154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1155
  graph_input = {
1156
  "messages": [
1157
  SystemMessage(content=self.system_prompt),
1158
- HumanMessage(content=question + (f"\n\n[FILE ATTACHED: {file_path}]" if file_path else ""))
1159
  ],
1160
- "file_path": file_path, # Add this to the graph state,
1161
  "turn": 0,
1162
  "has_plan": False,
1163
  "consecutive_errors": 0,
@@ -1287,22 +1386,23 @@ except Exception as e:
1287
  traceback.print_exc()
1288
  agent = None
1289
 
1290
- # ====================================================
1291
- # --- (Original Template Code - Mock Questions Version) ---
1292
- def run_and_submit_all( profile: gr.OAuthProfile | None):
 
1293
  """
1294
  Fetches all questions, runs the BasicAgent on them, submits all answers,
1295
  and displays the results.
1296
  """
1297
- # --- Determine HF Space Runtime URL and Repo URL ---
1298
- space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
1299
 
1300
  if profile:
1301
- username= f"{profile.username}"
1302
  print(f"User logged in: {username}")
1303
  else:
1304
  print("User not logged in.")
1305
  return "Please Login to Hugging Face with the button.", None
 
1306
  # Use the globally instantiated agent
1307
  global agent
1308
 
@@ -1317,13 +1417,6 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
1317
  questions_url = f"{api_url}/questions"
1318
  submit_url = f"{api_url}/submit"
1319
 
1320
- # 1. Instantiate Agent ( modify this part to create your agent)
1321
- #try:
1322
- # agent = BasicAgent()
1323
- #except Exception as e:
1324
- # print(f"Error instantiating agent: {e}")
1325
- # return f"Error initializing agent: {e}", None
1326
- # In the case of an app running as a hugging Face space, this link points toward your codebase ( usefull for others so please keep it public)
1327
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
1328
  print(agent_code)
1329
 
@@ -1357,11 +1450,10 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
1357
  task_id = item.get("task_id")
1358
  question_text = item.get("question")
1359
 
1360
- # 🌟 Initialize file variables for the current question
1361
  local_file_path = None
1362
- file_info = ""
1363
 
1364
- # 🌟 CRITICAL: Check if 'file_path' exists in the item dictionary
1365
  if item.get("file_path"):
1366
  file_path_from_api = item["file_path"]
1367
  file_download_url = f"{DEFAULT_API_URL}/files/{task_id}"
@@ -1370,9 +1462,11 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
1370
  original_filename = file_path_from_api.split('/')[-1]
1371
 
1372
  # Set the path where the file will be saved locally
1373
- local_file_path = os.path.join("/tmp", original_filename)
1374
 
1375
- # --- (Add streaming update here for file download status) ---
 
 
1376
 
1377
  try:
1378
  file_response = requests.get(file_download_url, timeout=15)
@@ -1380,30 +1474,37 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
1380
 
1381
  # Save the raw bytes content to the local file path
1382
  with open(local_file_path, 'wb') as f:
1383
- f.write(file_response.content)
1384
-
1385
- print(f"βœ… Downloaded file to: {local_file_path}")
1386
 
1387
- # Set the context string to be passed to the agent
1388
- file_info = f"\n\n[FILE ATTACHED: {local_file_path}]"
1389
 
 
 
 
 
 
1390
  except requests.exceptions.RequestException as e:
1391
  error_message = f"[FILE DOWNLOAD ERROR: Could not fetch file: {e}]"
1392
  print(f"⚠️ {error_message}")
1393
- # Still provide the error message as context to the agent
1394
- file_info = f"\n\n{error_message}"
 
 
 
1395
 
1396
  if not task_id or question_text is None:
1397
  print(f"Skipping item with missing task_id or question: {item}")
1398
  continue
 
1399
  try:
1400
- question_with_context = question_text + file_info
1401
- #submitted_answer = agent(question_with_context)
1402
- submitted_answer = agent(question_text, local_file_path if item.get("file_path") else None)
1403
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
1404
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
1405
  except Exception as e:
1406
  print(f"Error running agent on task {task_id}: {e}")
 
1407
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
1408
 
1409
  if not answers_payload:
@@ -1480,7 +1581,6 @@ with gr.Blocks() as demo:
1480
  run_button = gr.Button("Run Evaluation & Submit All Answers")
1481
 
1482
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
1483
- # Removed max_rows=10 from DataFrame constructor
1484
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
1485
 
1486
  run_button.click(
@@ -1490,9 +1590,8 @@ with gr.Blocks() as demo:
1490
 
1491
  if __name__ == "__main__":
1492
  print("\n" + "-"*30 + " App Starting " + "-"*30)
1493
- # Check for SPACE_HOST and SPACE_ID at startup for information
1494
  space_host_startup = os.getenv("SPACE_HOST")
1495
- space_id_startup = os.getenv("SPACE_ID") # Get SPACE_ID at startup
1496
 
1497
  if space_host_startup:
1498
  print(f"βœ… SPACE_HOST found: {space_host_startup}")
@@ -1500,7 +1599,7 @@ if __name__ == "__main__":
1500
  else:
1501
  print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
1502
 
1503
- if space_id_startup: # Print repo URLs if SPACE_ID is found
1504
  print(f"βœ… SPACE_ID found: {space_id_startup}")
1505
  print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
1506
  print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")
 
21
  from youtube_transcript_api import YouTubeTranscriptApi
22
  from bs4 import BeautifulSoup
23
  import requests
24
+ from PIL import Image
25
+ import base64
26
 
27
  # LangChain & LangGraph
28
  from langgraph.graph.message import add_messages
 
34
  from langchain_google_genai import ChatGoogleGenerativeAI
35
  from langchain_community.llms import HuggingFaceHub
36
 
 
37
  # RAG
38
  from langchain_text_splitters import RecursiveCharacterTextSplitter
39
  from langchain_community.vectorstores import FAISS
 
471
  return f"Transcription error: {str(e)}"
472
 
473
 
474
class ImageAnalysisInput(BaseModel):
    """Argument schema for the analyze_image tool."""
    # Path to the image on disk, as given in the [FILE ATTACHED: ...] marker.
    file_path: str = Field(description="Image file path")
    # Natural-language instruction describing what to extract from the image.
    query: str = Field(description="What to analyze in the image")
477
+
478
@tool(args_schema=ImageAnalysisInput)
def analyze_image(file_path: str, query: str) -> str:
    """
    Analyzes images using Google Gemini Vision API.
    Use for: chess positions, diagrams, charts, photos, screenshots.
    Provide the EXACT file path from [FILE ATTACHED: ...] in the question.

    Returns the model's analysis as a string, or an "Error: ..." /
    "Image analysis error: ..." string on failure (never raises).
    """
    if not file_path or not query:
        return "Error: file_path and query required."

    print(f"🖼️ Analyzing image: {file_path}")
    print(f"   Query: {query[:100]}...")

    # Try the project-level file resolver first ...
    image_path = find_file(file_path)

    # ... then fall back to the literal path (covers files saved under /tmp).
    if not image_path and os.path.exists(file_path):
        image_path = Path(file_path)

    if not image_path or not image_path.exists():
        return f"Error: Image not found at '{file_path}'. Check [FILE ATTACHED: ...] in question for correct path."

    print(f"✓ Found image at: {image_path}")

    try:
        GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
        if not GOOGLE_API_KEY:
            return "Error: GEMINI_API_KEY not set."

        # Load the image so we can re-encode it as base64 JPEG for the API.
        img = Image.open(image_path)
        print(f"   Image size: {img.size}, mode: {img.mode}")

        # FIX: JPEG cannot encode an alpha channel or palette modes. The old
        # check (`mode not in ['RGB', 'RGBA']`) let RGBA images through, and
        # img.save(..., format="JPEG") then raised
        # "OSError: cannot write mode RGBA as JPEG" for any PNG with alpha.
        # Flatten everything that is not already RGB.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        # Encode as base64 JPEG for the data-URI payload.
        buffered = io.BytesIO()
        img.save(buffered, format="JPEG")
        img_base64 = base64.b64encode(buffered.getvalue()).decode()

        print(f"   Encoded image: {len(img_base64)} bytes")

        # Use Gemini Vision (temperature 0 for deterministic extraction).
        vision_llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash-exp",
            google_api_key=GOOGLE_API_KEY,
            temperature=0
        )

        # Multimodal message: text part + inline data-URI image part.
        message = HumanMessage(
            content=[
                {"type": "text", "text": query},
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{img_base64}"
                }
            ]
        )

        print(f"   Sending to Gemini Vision...")
        response = vision_llm.invoke([message])
        print(f"✓ Got response: {len(response.content)} chars")

        return f"Image Analysis:\n{truncate_if_needed(response.content)}"

    except Exception as e:
        # Return (not raise) the error so the agent loop can react to it.
        error_msg = f"Image analysis error: {str(e)}"
        print(f"❌ {error_msg}")
        print(traceback.format_exc())
        return error_msg
551
+
552
+
553
  class YoutubeInput(BaseModel):
554
  video_url: str = Field(description="YouTube URL")
555
 
 
571
  if not video_id:
572
  return f"Error: Could not extract video ID."
573
 
 
 
574
  # FIXED: Use get_transcript instead of list_transcripts
575
  transcript_list = YouTubeTranscriptApi.get_transcript(video_id, languages=['en'])
576
 
 
692
 
693
  # Specialized
694
  audio_transcription_tool,
695
+ analyze_image, # NEW: Image analysis tool
696
  get_youtube_transcript,
697
  scrape_and_retrieve,
698
 
 
801
 
802
 
803
  # =============================================================================
804
+ # CONDITIONAL EDGE FUNCTION
805
  # =============================================================================
806
  def should_continue(state: AgentState):
807
  """Decide next step with robust logic."""
 
849
 
850
  # 5. Default: continue to agent
851
  print(f"πŸ”„ Default β†’ continuing to agent")
 
852
 
853
 
854
  # =============================================================================
 
1228
  print(f"🎯 NEW QUESTION")
1229
  print(f"{'='*70}")
1230
  print(f"Q: {question[:200]}{'...' if len(question) > 200 else ''}")
1231
+ if file_path:
1232
+ print(f"πŸ“Ž File attached: {file_path}")
1233
  print(f"{'='*70}\n")
1234
 
1235
+ # Enhanced question context with file information
1236
+ question_text = question
1237
+ if file_path:
1238
+ file_ext = Path(file_path).suffix.lower()
1239
+ file_type = "unknown"
1240
+
1241
+ if file_ext in ['.jpg', '.jpeg', '.png', '.gif', '.bmp']:
1242
+ file_type = "image"
1243
+ elif file_ext in ['.mp3', '.wav', '.m4a', '.flac']:
1244
+ file_type = "audio"
1245
+ elif file_ext in ['.csv', '.xlsx', '.xls']:
1246
+ file_type = "data"
1247
+ elif file_ext in ['.txt', '.pdf', '.doc', '.docx']:
1248
+ file_type = "document"
1249
+
1250
+ question_text += f"\n\n[FILE ATTACHED: {file_path}]"
1251
+ question_text += f"\n[FILE TYPE: {file_type}]"
1252
+ question_text += f"\nIMPORTANT: Use the appropriate tool to access this file first!"
1253
+
1254
  graph_input = {
1255
  "messages": [
1256
  SystemMessage(content=self.system_prompt),
1257
+ HumanMessage(content=question_text)
1258
  ],
1259
+ "file_path": file_path,
1260
  "turn": 0,
1261
  "has_plan": False,
1262
  "consecutive_errors": 0,
 
1386
  traceback.print_exc()
1387
  agent = None
1388
 
1389
+ # =============================================================================
1390
+ # RUN AND SUBMIT FUNCTION
1391
+ # =============================================================================
1392
+ def run_and_submit_all(profile: gr.OAuthProfile | None):
1393
  """
1394
  Fetches all questions, runs the BasicAgent on them, submits all answers,
1395
  and displays the results.
1396
  """
1397
+ space_id = os.getenv("SPACE_ID")
 
1398
 
1399
  if profile:
1400
+ username = f"{profile.username}"
1401
  print(f"User logged in: {username}")
1402
  else:
1403
  print("User not logged in.")
1404
  return "Please Login to Hugging Face with the button.", None
1405
+
1406
  # Use the globally instantiated agent
1407
  global agent
1408
 
 
1417
  questions_url = f"{api_url}/questions"
1418
  submit_url = f"{api_url}/submit"
1419
 
 
 
 
 
 
 
 
1420
  agent_code = f"https://huggingface.co/spaces/{space_id}/tree/main"
1421
  print(agent_code)
1422
 
 
1450
  task_id = item.get("task_id")
1451
  question_text = item.get("question")
1452
 
1453
+ # Initialize file variables for the current question
1454
  local_file_path = None
 
1455
 
1456
+ # Check if 'file_path' exists in the item dictionary
1457
  if item.get("file_path"):
1458
  file_path_from_api = item["file_path"]
1459
  file_download_url = f"{DEFAULT_API_URL}/files/{task_id}"
 
1462
  original_filename = file_path_from_api.split('/')[-1]
1463
 
1464
  # Set the path where the file will be saved locally
1465
+ local_file_path = os.path.join("/tmp", original_filename)
1466
 
1467
+ print(f"πŸ“₯ Downloading file for task {task_id}...")
1468
+ print(f" URL: {file_download_url}")
1469
+ print(f" Saving to: {local_file_path}")
1470
 
1471
  try:
1472
  file_response = requests.get(file_download_url, timeout=15)
 
1474
 
1475
  # Save the raw bytes content to the local file path
1476
  with open(local_file_path, 'wb') as f:
1477
+ f.write(file_response.content)
 
 
1478
 
1479
+ file_size = os.path.getsize(local_file_path)
1480
+ print(f"βœ… Downloaded file: {original_filename} ({file_size} bytes)")
1481
 
1482
+ # Verify file exists and is readable
1483
+ if not os.path.exists(local_file_path):
1484
+ print(f"⚠️ Warning: File saved but cannot be found at {local_file_path}")
1485
+ local_file_path = None
1486
+
1487
  except requests.exceptions.RequestException as e:
1488
  error_message = f"[FILE DOWNLOAD ERROR: Could not fetch file: {e}]"
1489
  print(f"⚠️ {error_message}")
1490
+ local_file_path = None
1491
+ except Exception as e:
1492
+ error_message = f"[FILE SAVE ERROR: {e}]"
1493
+ print(f"⚠️ {error_message}")
1494
+ local_file_path = None
1495
 
1496
  if not task_id or question_text is None:
1497
  print(f"Skipping item with missing task_id or question: {item}")
1498
  continue
1499
+
1500
  try:
1501
+ # Pass file_path to agent
1502
+ submitted_answer = agent(question_text, local_file_path)
 
1503
  answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
1504
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
1505
  except Exception as e:
1506
  print(f"Error running agent on task {task_id}: {e}")
1507
+ print(traceback.format_exc())
1508
  results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})
1509
 
1510
  if not answers_payload:
 
1581
  run_button = gr.Button("Run Evaluation & Submit All Answers")
1582
 
1583
  status_output = gr.Textbox(label="Run Status / Submission Result", lines=5, interactive=False)
 
1584
  results_table = gr.DataFrame(label="Questions and Agent Answers", wrap=True)
1585
 
1586
  run_button.click(
 
1590
 
1591
  if __name__ == "__main__":
1592
  print("\n" + "-"*30 + " App Starting " + "-"*30)
 
1593
  space_host_startup = os.getenv("SPACE_HOST")
1594
+ space_id_startup = os.getenv("SPACE_ID")
1595
 
1596
  if space_host_startup:
1597
  print(f"βœ… SPACE_HOST found: {space_host_startup}")
 
1599
  else:
1600
  print("ℹ️ SPACE_HOST environment variable not found (running locally?).")
1601
 
1602
+ if space_id_startup:
1603
  print(f"βœ… SPACE_ID found: {space_id_startup}")
1604
  print(f" Repo URL: https://huggingface.co/spaces/{space_id_startup}")
1605
  print(f" Repo Tree URL: https://huggingface.co/spaces/{space_id_startup}/tree/main")