# Standard Library
import ast
import gc
import glob
import hashlib
import json
import logging
import os
import re
import shutil
import signal
import string
import subprocess
import sys
import tempfile
import time
import uuid
from collections import Counter, defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager
from datetime import datetime
from functools import partial
from io import BytesIO
from pathlib import Path
from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set, Type
from urllib.parse import urljoin, urlparse
# Third-Party Packages
import cv2
import librosa
import nltk
import numpy as np
import pandas as pd
import requests
import spacy
import speech_recognition as sr
import torch
import wikipedia
import yt_dlp
from bs4 import BeautifulSoup
from duckduckgo_search import DDGS
from nltk.corpus import words
from PIL import Image
from pydantic import Field, BaseModel, Extra
from pydub import AudioSegment
from pydub.silence import split_on_silence
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import StandardScaler
from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer
# LangChain Ecosystem
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.retrievers import BM25Retriever
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_core.documents import Document
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description
from langchain_huggingface import HuggingFaceEndpoint, HuggingFacePipeline
# LangGraph
from langgraph.graph import START, END, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
nlp = spacy.load("en_core_web_sm")
# Ensure the NLTK word list is downloaded
nltk.download('words', quiet=True)
english_words = set(words.words())
logger = logging.getLogger(__name__)
logging.basicConfig(filename='app.log', level=logging.INFO)
# --- Model Configuration ---
def create_llm_pipeline():
    # Alternative checkpoints kept for reference:
    # model_id = "meta-llama/Llama-2-13b-chat-hf"
    # model_id = "meta-llama/Llama-3.3-70B-Instruct"
    # model_id = "mistralai/Mistral-Small-24B-Base-2501"
    model_id = "mistralai/Mistral-7B-Instruct-v0.3"
    # model_id = "Meta-Llama/Llama-2-7b-chat-hf"
    # model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO"
    # model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF"
    # model_id = "mistralai/Mistral-7B-Instruct-v0.2"
    # model_id = "Qwen/Qwen2-7B-Instruct"
    # model_id = "GSAI-ML/LLaDA-8B-Instruct"
    return pipeline(
        "text-generation",
        model=model_id,
        device_map="cpu",
        torch_dtype=torch.float32,  # float16 is poorly supported on CPU; use float32 there
        max_new_tokens=1024,
        do_sample=False,  # greedy decoding; a sampling temperature would be ignored here
        repetition_penalty=1.2
    )
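# The HuggingFacePipeline import above is the LangChain wrapper for this pipeline.
# A minimal wiring sketch (illustrative only; left commented out because invoking
# it downloads the full model weights):
#
#   llm = HuggingFacePipeline(pipeline=create_llm_pipeline())
#   print(llm.invoke("In one sentence, what does FAISS do?"))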
# Define file extension sets for each category
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}


def get_file_type(filename: str) -> str:
    if not filename or '.' not in filename:
        return ''
    ext = filename.lower().rsplit('.', 1)[-1]
    dot_ext = f'.{ext}'
    if dot_ext in PICTURE_EXTENSIONS:
        return 'picture'
    elif dot_ext in AUDIO_EXTENSIONS:
        return 'audio'
    elif dot_ext in CODE_EXTENSIONS:
        return 'code'
    elif dot_ext in SPREADSHEET_EXTENSIONS:
        return 'spreadsheet'
    else:
        return 'unknown'
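# Spot checks for the router above (these hold for the sets as defined; note that
# a bare name with no dot maps to '' while an unknown extension maps to 'unknown'):
def _demo_get_file_type() -> None:
    assert get_file_type("photo.PNG") == "picture"
    assert get_file_type("memo.mp3") == "audio"
    assert get_file_type("script.py") == "code"
    assert get_file_type("sales.xlsx") == "spreadsheet"
    assert get_file_type("archive.zip") == "unknown"
    assert get_file_type("README") == ""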
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """
    Writes bytes to a file in the system temporary directory using the provided file_name.
    Returns the full path to the saved file.
    The file will persist until manually deleted or the OS cleans the temp directory.
    """
    temp_dir = "/tmp"  # /tmp is always writable in Hugging Face Spaces
    os.makedirs(temp_dir, exist_ok=True)
    file_path = os.path.join(temp_dir, file_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
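# Round-trip sketch for the helper above (safe to run; writes a small file to /tmp):
def _demo_write_bytes_to_temp_dir() -> None:
    path = write_bytes_to_temp_dir(b"hello", "demo.txt")
    assert Path(path).read_bytes() == b"hello"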
def extract_final_answer(text: str) -> str:
    """
    Extracts the answer after the last 'FINAL ANSWER:' (case-insensitive),
    removes any parenthetical immediately following a numeric answer,
    strips trailing punctuation, sorts comma-separated lists,
    and does not split numbers containing commas.
    Returns an empty string if the marker is not found.
    """
    marker = "FINAL ANSWER:"
    idx = text.lower().rfind(marker.lower())
    if idx == -1:
        return ""
    # Extract answer after marker
    result = text[idx + len(marker):].strip()
    # Task-specific normalization: prefer the canonical phrase "pure vanilla extract"
    if "pure vanilla extract" not in result:
        result = result.replace("vanilla extract", "pure vanilla extract")
    # Remove parenthetical immediately following a number at the start
    result = re.sub(r'^(\d+(?:\.\d+)?)\s*\(.*?\)', r'\1', result)
    # Remove trailing punctuation and whitespace
    result = result.rstrip(string.punctuation + " ")
    # Split on commas that are NOT followed by a three-digit group,
    # so thousands separators like "1,000" stay intact
    items = re.split(r',(?!\s*\d{3}\b)', result)
    # If we have a list, sort it
    if len(items) > 1:
        items = [item.strip() for item in items]
        # Try to sort numerically
        try:
            sorted_items = sorted(
                items,
                key=lambda x: float(re.sub(r'[^\d\.]', '', x))  # Strip everything but digits and '.'
            )
            return ', '.join(sorted_items)
        except ValueError:
            # Fallback: sort alphabetically
            sorted_items = sorted(items, key=lambda x: x.lower())
            return ', '.join(sorted_items)
    return result
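# Behavior sketch for extract_final_answer; each assertion holds for the code above:
def _demo_extract_final_answer() -> None:
    assert extract_final_answer("reasoning...\nFINAL ANSWER: b, e, a") == "a, b, e"  # alphabetical fallback
    assert extract_final_answer("FINAL ANSWER: 1,000") == "1,000"  # thousands separator preserved
    assert extract_final_answer("no marker anywhere") == ""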
class AudioTranscriptionInput(BaseModel):
    """Input schema for AudioTranscriptionTool."""
    file_path: str = Field(description="Path to the audio file to transcribe")
    engine: Optional[str] = Field(default="google", description="Speech recognition engine to use")
    language: Optional[str] = Field(default="en-US", description="Language of the audio")


class AudioTranscriptionTool(BaseTool):
    """Tool for transcribing audio files using local speech recognition."""
    name: str = "audio_transcription"
    description: str = """
    Transcribes voice memos and audio files (mp3, wav, m4a, flac, etc.) to text using local speech recognition.
    Input should be a dictionary with a 'file_path' key containing the path to the audio file.
    Optionally accepts 'engine' and 'language' parameters.
    Returns the transcribed text as a string.
    """
    args_schema: Type[BaseModel] = AudioTranscriptionInput

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Initialize the AudioTranscriptionTool."""
        super().__init__(**kwargs)
        self._init_speech_recognition()

    def _init_speech_recognition(self):
        """Initialize speech recognition components."""
        try:
            import speech_recognition as sr
            from pydub import AudioSegment
            object.__setattr__(self, 'recognizer', sr.Recognizer())
            object.__setattr__(self, 'sr', sr)
            object.__setattr__(self, 'AudioSegment', AudioSegment)
        except ImportError as e:
            raise ImportError(
                "Required libraries not found. Install with: "
                "pip install SpeechRecognition pydub"
            ) from e

    def _validate_audio_file(self, file_path: str) -> bool:
        """Validate that the audio file exists and has a supported format."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        # Check file extension - pydub supports many formats
        supported_formats = {'.mp3', '.wav', '.m4a', '.flac', '.mp4', '.mpeg', '.mpga', '.webm', '.ogg', '.aac'}
        file_extension = Path(file_path).suffix.lower()
        if file_extension not in supported_formats:
            raise ValueError(
                f"Unsupported audio format: {file_extension}. "
                f"Supported formats: {', '.join(supported_formats)}"
            )
        return True

    def _convert_to_wav(self, file_path: str) -> str:
        """Convert audio file to WAV format if needed."""
        file_extension = Path(file_path).suffix.lower()
        if file_extension == '.wav':
            return file_path
        try:
            # Convert to WAV using pydub
            audio = self.AudioSegment.from_file(file_path)
            # Create temporary WAV file
            temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            audio.export(temp_wav.name, format="wav")
            return temp_wav.name
        except Exception as e:
            raise RuntimeError(f"Error converting audio file to WAV: {str(e)}")

    def _transcribe_audio(self, file_path: str, engine: str = "google", language: str = "en-US") -> str:
        """Transcribe audio file using local speech recognition."""
        temp_wav_path = None
        try:
            # Convert to WAV if necessary
            wav_path = self._convert_to_wav(file_path)
            if wav_path != file_path:
                temp_wav_path = wav_path
            # Load audio file
            with self.sr.AudioFile(wav_path) as source:
                # Adjust for ambient noise
                self.recognizer.adjust_for_ambient_noise(source, duration=0.5)
                # Record the audio
                audio_data = self.recognizer.record(source)
            # Choose recognition engine
            if engine == "google":
                transcript = self.recognizer.recognize_google(audio_data, language=language)
            elif engine == "sphinx":
                transcript = self.recognizer.recognize_sphinx(audio_data, language=language)
            elif engine == "wit":
                # Note: requires WIT_AI_KEY environment variable
                wit_key = os.getenv('WIT_AI_KEY')
                if not wit_key:
                    raise ValueError("WIT_AI_KEY environment variable required for Wit.ai engine")
                transcript = self.recognizer.recognize_wit(audio_data, key=wit_key)
            elif engine == "bing":
                # Note: requires BING_KEY environment variable
                bing_key = os.getenv('BING_KEY')
                if not bing_key:
                    raise ValueError("BING_KEY environment variable required for Bing engine")
                transcript = self.recognizer.recognize_bing(audio_data, key=bing_key, language=language)
            else:
                # Default to Google
                transcript = self.recognizer.recognize_google(audio_data, language=language)
            return transcript
        except self.sr.UnknownValueError:
            return "Could not understand the audio - speech was unclear or inaudible"
        except self.sr.RequestError as e:
            return f"Error with speech recognition service: {str(e)}"
        except Exception as e:
            raise RuntimeError(f"Error transcribing audio: {str(e)}")
        finally:
            # Clean up temporary WAV file
            if temp_wav_path and os.path.exists(temp_wav_path):
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    pass  # Ignore cleanup errors

    def _run(self, file_path: str, engine: str = "google", language: str = "en-US", **kwargs) -> str:
        """
        Internal method required by LangChain BaseTool.

        Args:
            file_path: Path to the audio file to transcribe
            engine: Speech recognition engine to use
            language: Language of the audio

        Returns:
            str: Transcribed text from the audio file
        """
        try:
            # Validate audio file
            self._validate_audio_file(file_path)
            # Transcribe audio
            transcript = self._transcribe_audio(
                file_path=file_path,
                engine=engine,
                language=language
            )
            return transcript
        except Exception as e:
            error_msg = f"AudioTranscriptionTool error: {str(e)}"
            print(error_msg)
            return error_msg

    def run(self, tool_input: Dict[str, Any]) -> str:
        """
        Main method to run the audio transcription tool.

        Args:
            tool_input: Dictionary containing 'file_path' and optional parameters

        Returns:
            str: Transcribed text from the audio file
        """
        try:
            # Extract parameters from input
            file_path = tool_input.get('file_path')
            if not file_path:
                raise ValueError("file_path is required in tool_input")
            engine = tool_input.get('engine', 'google')
            language = tool_input.get('language', 'en-US')
            # Call the internal _run method
            return self._run(file_path=file_path, engine=engine, language=language)
        except Exception as e:
            error_msg = f"AudioTranscriptionTool error: {str(e)}"
            print(error_msg)
            return error_msg
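# Illustrative invocation (hedged: assumes an audio file exists at the given path,
# and the default Google engine needs network access):
#
#   transcriber = AudioTranscriptionTool()
#   print(transcriber.run({"file_path": "/tmp/voice_memo.mp3", "language": "en-US"}))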
# Enhanced local transcription tool with multiple engine support
class AdvancedAudioTranscriptionTool(BaseTool):
    """Advanced tool with support for multiple local transcription engines including Whisper."""
    name: str = "advanced_audio_transcription"
    description: str = """
    Advanced audio transcription tool supporting multiple engines including local Whisper.
    Supports engines: 'whisper' (local), 'google', 'sphinx', 'wit', 'bing'.
    Input should be a dictionary with a 'file_path' key.
    Returns the transcribed text as a string.
    """
    args_schema: Type[BaseModel] = AudioTranscriptionInput

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **kwargs):
        """Initialize the AdvancedAudioTranscriptionTool."""
        super().__init__(**kwargs)
        self._init_speech_recognition()
        self._init_whisper()

    def _init_speech_recognition(self):
        """Initialize speech recognition components."""
        try:
            import speech_recognition as sr
            from pydub import AudioSegment
            object.__setattr__(self, 'recognizer', sr.Recognizer())
            object.__setattr__(self, 'sr', sr)
            object.__setattr__(self, 'AudioSegment', AudioSegment)
        except ImportError as e:
            raise ImportError(
                "Required libraries not found. Install with: "
                "pip install SpeechRecognition pydub"
            ) from e

    def _init_whisper(self):
        """Initialize Whisper if available."""
        try:
            import whisper
            object.__setattr__(self, 'whisper', whisper)
        except ImportError:
            object.__setattr__(self, 'whisper', None)
            print("Warning: OpenAI Whisper not installed. Install with 'pip install openai-whisper' for local Whisper support.")

    def _validate_audio_file(self, file_path: str) -> bool:
        """Validate that the audio file exists and has a supported format."""
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Audio file not found: {file_path}")
        supported_formats = {'.mp3', '.wav', '.m4a', '.flac', '.mp4', '.mpeg', '.mpga', '.webm', '.ogg', '.aac'}
        file_extension = Path(file_path).suffix.lower()
        if file_extension not in supported_formats:
            raise ValueError(
                f"Unsupported audio format: {file_extension}. "
                f"Supported formats: {', '.join(supported_formats)}"
            )
        return True

    def _transcribe_with_whisper(self, file_path: str, language: str = "en") -> str:
        """Transcribe using a local Whisper model."""
        if not self.whisper:
            raise RuntimeError("Whisper not installed. Install with 'pip install openai-whisper'")
        try:
            # Load the model (available sizes: tiny, base, small, medium, large)
            model = self.whisper.load_model("base")
            # Transcribe the audio
            result = model.transcribe(file_path, language=language if language != "en-US" else "en")
            return result["text"].strip()
        except Exception as e:
            raise RuntimeError(f"Error with Whisper transcription: {str(e)}")

    def _convert_to_wav(self, file_path: str) -> str:
        """Convert audio file to WAV format if needed."""
        file_extension = Path(file_path).suffix.lower()
        if file_extension == '.wav':
            return file_path
        try:
            audio = self.AudioSegment.from_file(file_path)
            temp_wav = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
            audio.export(temp_wav.name, format="wav")
            return temp_wav.name
        except Exception as e:
            raise RuntimeError(f"Error converting audio file to WAV: {str(e)}")

    def _transcribe_with_sr(self, file_path: str, engine: str = "google", language: str = "en-US") -> str:
        """Transcribe using the speech_recognition library."""
        temp_wav_path = None
        try:
            wav_path = self._convert_to_wav(file_path)
            if wav_path != file_path:
                temp_wav_path = wav_path
            with self.sr.AudioFile(wav_path) as source:
                self.recognizer.adjust_for_ambient_noise(source, duration=0.5)
                audio_data = self.recognizer.record(source)
            if engine == "google":
                transcript = self.recognizer.recognize_google(audio_data, language=language)
            elif engine == "sphinx":
                transcript = self.recognizer.recognize_sphinx(audio_data)
            elif engine == "wit":
                wit_key = os.getenv('WIT_AI_KEY')
                if not wit_key:
                    raise ValueError("WIT_AI_KEY environment variable required for Wit.ai engine")
                transcript = self.recognizer.recognize_wit(audio_data, key=wit_key)
            elif engine == "bing":
                bing_key = os.getenv('BING_KEY')
                if not bing_key:
                    raise ValueError("BING_KEY environment variable required for Bing engine")
                transcript = self.recognizer.recognize_bing(audio_data, key=bing_key, language=language)
            else:
                transcript = self.recognizer.recognize_google(audio_data, language=language)
            return transcript
        except self.sr.UnknownValueError:
            return "Could not understand the audio - speech was unclear or inaudible"
        except self.sr.RequestError as e:
            return f"Error with speech recognition service: {str(e)}"
        finally:
            if temp_wav_path and os.path.exists(temp_wav_path):
                try:
                    os.unlink(temp_wav_path)
                except OSError:
                    pass

    def _run(self, file_path: str, engine: str = "google", language: str = "en-US", **kwargs) -> str:
        """
        Internal method required by LangChain BaseTool.

        Args:
            file_path: Path to the audio file to transcribe
            engine: Speech recognition engine to use
            language: Language of the audio

        Returns:
            str: Transcribed text from the audio file
        """
        try:
            self._validate_audio_file(file_path)
            # Use local Whisper if specified
            if engine == "whisper":
                transcript = self._transcribe_with_whisper(file_path, language)
            else:
                # Use the speech_recognition library
                transcript = self._transcribe_with_sr(file_path, engine, language)
            return transcript
        except Exception as e:
            error_msg = f"AdvancedAudioTranscriptionTool error: {str(e)}"
            print(error_msg)
            return error_msg

    def run(self, tool_input: Dict[str, Any]) -> str:
        """
        Main method to run the advanced audio transcription tool.

        Args:
            tool_input: Dictionary containing 'file_path' and optional parameters

        Returns:
            str: Transcribed text from the audio file
        """
        try:
            file_path = tool_input.get('file_path')
            if not file_path:
                raise ValueError("file_path is required in tool_input")
            engine = tool_input.get('engine', 'google')
            language = tool_input.get('language', 'en-US')
            # Call the internal _run method
            return self._run(file_path=file_path, engine=engine, language=language)
        except Exception as e:
            error_msg = f"AdvancedAudioTranscriptionTool error: {str(e)}"
            print(error_msg)
            return error_msg
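# Whisper variant sketch (hedged: assumes `pip install openai-whisper` and an
# existing audio file; the "base" checkpoint is downloaded on first use):
#
#   adv = AdvancedAudioTranscriptionTool()
#   print(adv.run({"file_path": "/tmp/voice_memo.mp3", "engine": "whisper"}))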
class ExcelReaderInput(BaseModel):
    """Input schema for ExcelReaderTool."""
    file_path: str = Field(description="Path to the Excel file to read")


class ExcelReaderTool(BaseTool):
    """Tool for reading Excel files and formatting them for LLM consumption."""
    name: str = "excel_reader"
    description: str = (
        "Reads an Excel file from the specified file path. "
        "Use ONLY for running math operations on a table of sales data from a fast-food restaurant chain."
    )
    args_schema: Type[BaseModel] = ExcelReaderInput

    def _run(self, file_path: str, run_manager: Optional[Any] = None) -> str:
        """
        Execute the tool to read an Excel file and return a formatted table.

        Args:
            file_path: Path to the Excel file
            run_manager: Optional callback manager

        Returns:
            Formatted string representation of the Excel table
        """
        try:
            # Validate file exists
            if not os.path.exists(file_path):
                return f"Error: File not found at path: {file_path}"
            # Validate file extension
            if not file_path.lower().endswith(('.xlsx', '.xls')):
                return f"Error: File must be an Excel file (.xlsx or .xls). Got: {file_path}"
            # Read Excel file - specifically Sheet1
            try:
                df = pd.read_excel(file_path, sheet_name='Sheet1')
            except ValueError as e:
                if "Worksheet named 'Sheet1' not found" in str(e):
                    # If Sheet1 doesn't exist, try reading the first sheet
                    df = pd.read_excel(file_path, sheet_name=0)
                else:
                    raise e
            # Check if dataframe is empty
            if df.empty:
                return "The Excel file contains no data in Sheet1."
            # Format the table for LLM consumption
            formatted_output = self._format_table_for_llm(df, file_path)
            return formatted_output
        except FileNotFoundError:
            return f"Error: File not found at path: {file_path}"
        except PermissionError:
            return f"Error: Permission denied accessing file: {file_path}"
        except Exception as e:
            return f"Error reading Excel file: {str(e)}"

    def _format_table_for_llm(self, df: pd.DataFrame, file_path: str) -> str:
        """
        Format the pandas DataFrame into a readable string format for LLMs.

        Args:
            df: The pandas DataFrame containing the Excel data
            file_path: Original file path for reference

        Returns:
            Formatted string representation of the table
        """
        output_lines = []
        # Optional header/column metadata dump, disabled to keep the output small:
        # output_lines.append(f"EXCEL FILE DATA FROM: {os.path.basename(file_path)}")
        # output_lines.append("Sheet: Sheet1")
        # output_lines.append(f"Dimensions: {df.shape[0]} rows × {df.shape[1]} columns")
        # for i, col in enumerate(df.columns, 1):
        #     output_lines.append(f"  {i}. {col} ({df[col].dtype}) - {df[col].count()} non-null values")
        output_lines.append("TABLE DATA:")
        # Replace NaN with a readable placeholder (display only)
        df_clean = df.fillna("N/A")
        # Full table dump, disabled to keep the output small:
        # output_lines.append(df_clean.to_string(index=True, max_rows=None, max_cols=None))
        # Sum on the original frame: df_clean may hold "N/A" strings in numeric
        # columns, which would break a numeric sum
        numeric_cols = df.select_dtypes(include=['number']).columns
        sums = df[numeric_cols].sum()
        # Classify numeric columns: every item except 'Soda' counts as food
        food_cols = [col for col in numeric_cols if col.lower() != 'soda']
        drink_cols = [col for col in numeric_cols if col.lower() == 'soda']
        # Aggregate totals
        food_total = sums[food_cols].sum()
        drink_total = sums[drink_cols].sum()
        # Format the totals to two decimal places (dollar amounts, no currency symbol)
        formatted_totals = {
            'Food': f"{food_total:.2f}",
            'Drink': f"{drink_total:.2f}"
        }
        result_string = '\n'.join(f"{k}: {v}" for k, v in formatted_totals.items())
        output_lines.append(result_string)
        if len(numeric_cols) > 0:
            output_lines.append("-" * 60)
        return "\n".join(output_lines)

    async def _arun(self, file_path: str, run_manager: Optional[Any] = None) -> str:
        """Async version of the tool (falls back to sync implementation)."""
        return self._run(file_path, run_manager)
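# Illustrative call (hedged: assumes a workbook exists at the path shown; .invoke
# is the standard LangChain entry point and routes into _run):
#
#   reader = ExcelReaderTool()
#   print(reader.invoke({"file_path": "/tmp/sales.xlsx"}))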
class PythonExecutorInput(BaseModel):
    """Input schema for PythonExecutor tool."""
    file_path: str = Field(description="Path to the Python file to execute")


class PythonExecutorTool(BaseTool):
    """Tool that executes a Python file and returns the result."""
    name: str = "python_executor"
    description: str = "Executes a Python file from the given file path and returns the output"
    args_schema: Type[BaseModel] = PythonExecutorInput

    def _run(
        self,
        file_path: str,
        run_manager: Optional[Any] = None,
    ) -> str:
        """Execute the Python file and return the result."""
        try:
            # Validate that the file exists
            if not os.path.exists(file_path):
                return f"Error: File '{file_path}' does not exist"
            # Validate that it's a Python file
            if not file_path.endswith('.py'):
                return f"Error: '{file_path}' is not a Python file (.py extension required)"
            # Execute the Python file with the current interpreter
            result = subprocess.run(
                [sys.executable, file_path],
                capture_output=True,
                text=True,
                timeout=600  # 600-second (10-minute) timeout to prevent hanging
            )
            # Prepare the output
            output_parts = []
            if result.stdout:
                output_parts.append(f"STDOUT:\n{result.stdout}")
            if result.stderr:
                output_parts.append(f"STDERR:\n{result.stderr}")
            if result.returncode != 0:
                output_parts.append(f"Return code: {result.returncode}")
            if not output_parts:
                return "Script executed successfully with no output"
            return "\n\n".join(output_parts)
        except subprocess.TimeoutExpired:
            return "Error: Script execution timed out (600 seconds)"
        except Exception as e:
            return f"Error executing Python file: {str(e)}"

    async def _arun(
        self,
        file_path: str,
        run_manager: Optional[Any] = None,
    ) -> str:
        """Async version - delegates to sync implementation."""
        return self._run(file_path, run_manager)
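# End-to-end sketch for the executor (safe to run: writes a throwaway script to
# the temp dir and runs it with the current interpreter):
def _demo_python_executor() -> None:
    script_path = os.path.join(tempfile.gettempdir(), "executor_demo.py")
    with open(script_path, "w") as f:
        f.write("print('hello from the executor')\n")
    print(PythonExecutorTool()._run(script_path))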
class CommutativityAnalysisTool(BaseTool):
    """
    A tool that analyzes an algebraic operation table to find elements
    involved in counter-examples to commutativity.

    This tool executes predefined Python code that:
    1. Defines a set S = ['a', 'b', 'c', 'd', 'e']
    2. Defines an operation table as a dictionary of dictionaries
    3. Finds elements where table[x][y] != table[y][x] (non-commutative pairs)
    4. Returns a sorted, comma-separated list of all elements involved
    """
    name: str = "commutativity_analysis"
    description: str = (
        "Analyzes an algebraic operation table to find elements involved in "
        "counter-examples to commutativity. Returns a comma-separated list of "
        "elements where the operation is not commutative. "
        "Provides a direct answer to questions on commutativity analysis."
    )
    return_direct: bool = False

    def _run(
        self,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        **kwargs: Any
    ) -> str:
        """Execute the commutativity analysis synchronously."""
        try:
            # Define the set and the operation table
            S = ['a', 'b', 'c', 'd', 'e']
            # The operation table as a dictionary of dictionaries
            table = {
                'a': {'a': 'a', 'b': 'b', 'c': 'c', 'd': 'b', 'e': 'd'},
                'b': {'a': 'b', 'b': 'c', 'c': 'a', 'd': 'e', 'e': 'c'},
                'c': {'a': 'c', 'b': 'a', 'c': 'b', 'd': 'b', 'e': 'a'},
                'd': {'a': 'b', 'b': 'e', 'c': 'b', 'd': 'e', 'e': 'd'},
                'e': {'a': 'd', 'b': 'b', 'c': 'a', 'd': 'd', 'e': 'c'}
            }
            # Find elements involved in counter-examples to commutativity
            involved = set()
            for x in S:
                for y in S:
                    if table[x][y] != table[y][x]:
                        involved.add(x)
                        involved.add(y)
            # Output the result as a comma-separated, alphabetically sorted list
            result = "subset of S involved in any possible counter-examples: " + ', '.join(sorted(involved))
            return result
        except Exception as e:
            return f"Error executing commutativity analysis: {str(e)}"

    async def _arun(
        self,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
        **kwargs: Any
    ) -> str:
        """Execute the commutativity analysis asynchronously."""
        # For this simple computation, we can just call the synchronous version
        return self._run()
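# Sanity check: in the table above only table['b']['e'] ('c') differs from
# table['e']['b'] ('b'), so the reported subset is exactly "b, e":
def _demo_commutativity_analysis() -> None:
    assert CommutativityAnalysisTool()._run().endswith("b, e")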
class EnhancedDuckDuckGoSearchTool(BaseTool):
    name: str = "enhanced_search"
    description: str = (
        "Performs a DuckDuckGo web search and retrieves actual content from the top web results. "
        "Input should be a search query string. "
        "Returns search results with extracted content from web pages, making it much more useful for answering questions. "
        "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. "
        "Ideal for topics that require the latest news, recent developments, or information not covered in static sources."
    )
    max_results: int = 3
    max_chars_per_page: int = 18000
    session: Any = None

    def model_post_init(self, __context: Any) -> None:
        super().model_post_init(__context)
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        })

    def _search_duckduckgo(self, query_term: str) -> List[Dict]:
        """Perform a DuckDuckGo search and return results."""
        try:
            with DDGS() as ddgs:
                results = list(ddgs.text(query_term, max_results=self.max_results))
            return results
        except Exception as e:
            logger.error(f"DuckDuckGo search failed: {e}")
            return []

    def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]:
        """Extract clean text content from a web page."""
        try:
            if any(url.lower().endswith(ext) for ext in ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']):
                return "Content type not supported for extraction"
            response = self.session.get(url, timeout=timeout, allow_redirects=True)
            response.raise_for_status()
            content_type = response.headers.get('content-type', '').lower()
            if 'text/html' not in content_type:
                return "Non-HTML content detected"
            soup = BeautifulSoup(response.content, 'html.parser')
            # Drop scripts, styles, and page chrome before extracting text
            for script_or_style in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
                script_or_style.decompose()
            main_content = None
            for selector in ['main', 'article', '.content', '#content', '.post', '.entry-content', '.entry']:
                main_content = soup.select_one(selector)
                if main_content:
                    break
            if not main_content:
                main_content = soup.find('body') or soup
            text = main_content.get_text(separator='\n', strip=True)
            lines = [line.strip() for line in text.split('\n') if line.strip()]
            text = '\n'.join(lines)
            text = re.sub(r'\n{3,}', '\n\n', text)
            text = re.sub(r' {2,}', ' ', text)
            if len(text) > self.max_chars_per_page:
                text = text[:self.max_chars_per_page] + "\n[Content truncated...]"
            return text
        except requests.exceptions.Timeout:
            logger.warning(f"Page loading timed out for {url}")
            return "Page loading timed out"
        except requests.exceptions.RequestException as e:
            logger.warning(f"Failed to retrieve page {url}: {str(e)}")
            return f"Failed to retrieve page: {str(e)}"
        except Exception as e:
            logger.error(f"Content extraction failed for {url}: {e}")
            return "Failed to extract content from page"

    def _format_search_result(self, result: Dict, content: str) -> str:
        """Format a single search result with its content."""
        title = result.get('title', 'No title')
        url = result.get('href', 'No URL')
        snippet = result.get('body', 'No snippet')
        return f"🔍 **{title}**\nURL: {url}\nSnippet: {snippet}\n\n📄 **Page Content:**\n{content}\n---\n"

    def run(self, tool_input: Union[str, Dict]) -> str:
        """Execute the enhanced search. Accepts a raw query string or a dict."""
        query_str: Optional[str] = None
        if isinstance(tool_input, dict):
            if "query" in tool_input and isinstance(tool_input["query"], str):
                query_str = tool_input["query"]
            elif "input" in tool_input and isinstance(tool_input["input"], str):
                query_str = tool_input["input"]
            else:
                return "Invalid input: Dictionary received, but does not contain a recognizable string query under 'query' or 'input' keys."
        elif isinstance(tool_input, str):
            query_str = tool_input
        else:
            return f"Invalid input type: Expected a string or a dictionary, but got {type(tool_input).__name__}."
        if not query_str or not query_str.strip():
            return "Please provide a search query."
        query_str = query_str.strip()
        logger.info(f"Searching for: {query_str}")
        search_results = self._search_duckduckgo(query_str)
        if not search_results:
            return f"No search results found for query: {query_str}"
        enhanced_results = []
        processed_count = 0
        for i, result in enumerate(search_results[:self.max_results]):
            url = result.get('href', '')
            if not url:
                continue
            logger.info(f"Processing result {i+1}: {url}")
            content = self._extract_content_from_url(url)
            if content and len(content.strip()) > 50:
                formatted_result = self._format_search_result(result, content)
                enhanced_results.append(formatted_result)
                processed_count += 1
            time.sleep(0.5)  # Polite delay between page fetches
        if not enhanced_results:
            return f"Search completed but no content could be extracted from the pages for query: {query_str}"
        response = f"""🔍 **Enhanced Search Results for: "{query_str}"**

Found {len(search_results)} results, successfully processed {processed_count} pages with content.

{''.join(enhanced_results)}

💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query.
"""
        if len(response) > 18000:
            response = response[:18000] + "\n[Response truncated to prevent memory issues]"
        return response

    def _run(self, query_or_tool_input: Union[str, Dict]) -> str:
        """Required by the BaseTool interface; delegates to run(), which handles both input types."""
        return self.run(query_or_tool_input)
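# Illustrative usage (network access required; DuckDuckGo may throttle rapid calls):
#
#   search_tool = EnhancedDuckDuckGoSearchTool()
#   print(search_tool.run("current stable Python release"))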
# --- Agent State Definition ---
class AgentState(TypedDict):
    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
    done: bool  # TypedDict fields cannot declare defaults; pass done=False when building the state
    question: str
    task_id: str
    input_file: Optional[bytes]
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    frame_answers: Optional[list]
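# Sketch of how this state is used (illustrative, not executed): the Annotated
# reducer above merges message lists by concatenation, so a LangGraph node that
# returns {"messages": [reply]} appends to history instead of overwriting it.
# A fully-populated initial state would look like:
#
#   state: AgentState = {
#       "messages": [HumanMessage(content="What is 2 + 2?")],
#       "done": False,
#       "question": "What is 2 + 2?",
#       "task_id": "demo-001",
#       "input_file": None,
#       "file_type": None,
#       "context": [],
#       "file_path": None,
#       "youtube_url": None,
#       "answer": None,
#       "frame_answers": None,
#   }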
def fetch_page_with_tables(page_title):
    """
    Fetches Wikipedia page content and extracts all tables as readable text.
    Returns a tuple: (main_text, [table_texts])
    """
    # Fetch the page object
    page = wikipedia.page(page_title)
    main_text = page.content
    # Get the HTML for table extraction
    html = page.html()
    soup = BeautifulSoup(html, 'html.parser')
    tables = soup.find_all('table')
    table_texts = []
    for table in tables:
        rows = table.find_all('tr')
        table_lines = []
        for row in rows:
            cells = row.find_all(['th', 'td'])
            cell_texts = [cell.get_text(strip=True) for cell in cells]
            if cell_texts:
                # Format as a Markdown-style table row
                table_lines.append(" | ".join(cell_texts))
        if table_lines:
            table_text = "\n".join(table_lines)
            table_texts.append(table_text)
    return main_text, table_texts
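# Usage sketch (network access to Wikipedia required; the title is just an example):
#
#   text, tables = fetch_page_with_tables("Python (programming language)")
#   print(text[:300])
#   print(f"extracted {len(tables)} tables")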
| class WikipediaSearchToolWithFAISS(BaseTool): | |
| name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval" | |
| description: str = ( | |
| "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing " | |
| "of various search candidates, with strong prioritization of query entities. It then performs " | |
| "entity-focused semantic search across all fetched content to find the most relevant information, " | |
| "with improved retrieval for lists like discographies. Uses spaCy for named entity " | |
| "recognition and query enhancement. Input should be a search query or topic. " | |
| "Note: Uses the current live version of Wikipedia." | |
| ) | |
| embedding_model_name: str = "all-MiniLM-L6-v2" | |
| chunk_size: int = 4000 | |
| chunk_overlap: int = 250 # Maintained moderate overlap | |
| top_k_results: int = 3 | |
| spacy_model: str = "en_core_web_sm" | |
| # Increased multiplier to fetch more candidates per semantic query variant | |
| semantic_search_candidate_multiplier: int = 1 # Was 2, increased to 3, consider 4 if still problematic | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| try: | |
| self._nlp = spacy.load(self.spacy_model) | |
| print(f"Loaded spaCy model: {self.spacy_model}") | |
| self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name) | |
| # Refined separators for better handling of Wikipedia lists and sections | |
| self._text_splitter = RecursiveCharacterTextSplitter( | |
| chunk_size=self.chunk_size, | |
| chunk_overlap=self.chunk_overlap, | |
| separators=[ | |
| "\n\n== ", "\n\n=== ", "\n\n==== ", # Section headers (keep with following content) | |
| "\n\n\n", "\n\n", # Multiple newlines (paragraph breaks) | |
| "\n* ", "\n- ", "\n# ", # List items | |
| "\n", ". ", "! ", "? ", # Sentence breaks after newline, common punctuation | |
| " ", "" # Word and character level | |
| ] | |
| ) | |
| except OSError as e: | |
| print(f"Error loading spaCy model '{self.spacy_model}': {e}") | |
| print("Try running: python -m spacy download en_core_web_sm") | |
| self._nlp = None | |
| self._embedding_model = None | |
| self._text_splitter = None | |
| except Exception as e: | |
| print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}") | |
| self._nlp = None | |
| self._embedding_model = None | |
| self._text_splitter = None | |
| def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]: | |
| if not self._nlp: | |
| return [], [], query | |
| doc = self._nlp(query) | |
| main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]] | |
| keywords = [token.lemma_.lower() for token in doc if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2] | |
| main_entities = list(dict.fromkeys(main_entities)) | |
| keywords = list(dict.fromkeys(keywords)) | |
| processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()] | |
| processed_query = " ".join(processed_tokens) | |
| return main_entities, keywords, processed_query | |
| def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]: | |
| candidates_set = set() | |
| entity_prefix = main_entities[0] if main_entities else None | |
| for me in main_entities: | |
| candidates_set.add(me) | |
| candidates_set.add(query) | |
| if processed_query and processed_query != query: | |
| candidates_set.add(processed_query) | |
| if entity_prefix and keywords: | |
| first_entity_lower = entity_prefix.lower() | |
| for kw in keywords[:3]: | |
| if kw not in first_entity_lower and len(kw) > 2: | |
| candidates_set.add(f"{entity_prefix} {kw}") | |
| keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k)>2) | |
| if keyword_combo_short: candidates_set.add(f"{entity_prefix} {keyword_combo_short}") | |
| if len(main_entities) > 1: | |
| candidates_set.add(" ".join(main_entities[:2])) | |
| if keywords: | |
| keyword_combo = " ".join(keywords[:2]) | |
| if entity_prefix: | |
| candidate_to_add = f"{entity_prefix} {keyword_combo}" | |
| if not any(c.lower() == candidate_to_add.lower() for c in candidates_set): | |
| candidates_set.add(candidate_to_add) | |
| elif not main_entities: | |
| candidates_set.add(keyword_combo) | |
| ordered_candidates = [] | |
| for me in main_entities: | |
| if me not in ordered_candidates: ordered_candidates.append(me) | |
| for c in list(candidates_set): | |
| if c and c.strip() and c not in ordered_candidates: ordered_candidates.append(c) | |
| print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}") | |
| return ordered_candidates | |
| def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]: | |
| candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text) | |
| found_pages_data: List[Tuple[str, str]] = [] | |
| processed_page_titles: Set[str] = set() | |
| for i, candidate_query in enumerate(candidates): | |
| print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'") | |
| page_object = None | |
| final_page_title = None | |
| is_candidate_entity_focused = any(me.lower() in candidate_query.lower() for me in main_entities_from_query) if main_entities_from_query else False | |
| try: | |
| try: | |
| page_to_load = candidate_query | |
| suggest_mode = True # Default to auto_suggest=True | |
| if is_candidate_entity_focused and main_entities_from_query: | |
| try: # Attempt precise match first for entity-focused candidates | |
| temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True) | |
| suggest_mode = False # Flag that precise match worked | |
| except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError): | |
| print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.") | |
| # Fallthrough to auto_suggest=True below if this fails | |
| if suggest_mode: # If not attempted or failed with auto_suggest=False | |
| temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True) | |
| final_page_title = temp_page.title | |
| if is_candidate_entity_focused and main_entities_from_query: | |
| title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
| if not title_matches_main_entity: | |
| print(f" ! Page title '{final_page_title}' (from entity-focused candidate '{candidate_query}') " | |
| f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
| continue | |
| if final_page_title in processed_page_titles: | |
| print(f" ~ Already processed '{final_page_title}'") | |
| continue | |
| page_object = temp_page | |
| print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'") | |
| except wikipedia.exceptions.PageError: | |
| if i < max(2, len(candidates) // 3) : # Try Wikipedia search for a smaller, more promising subset of candidates | |
| print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...") | |
| search_results = wikipedia.search(candidate_query, results=1) | |
| if not search_results: | |
| print(f" - No Wikipedia search results for '{candidate_query}'.") | |
| continue | |
| search_result_title = search_results[0] | |
| try: | |
| temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True) # Search results are usually canonical | |
| final_page_title = temp_page.title | |
| if is_candidate_entity_focused and main_entities_from_query: # Still check against original intent | |
| title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
| if not title_matches_main_entity: | |
| print(f" ! Page title '{final_page_title}' (from search for '{candidate_query}' -> '{search_result_title}') " | |
| f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
| continue | |
| if final_page_title in processed_page_titles: | |
| print(f" ~ Already processed '{final_page_title}'") | |
| continue | |
| page_object = temp_page | |
| print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'") | |
| except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr: | |
| print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}") | |
| else: | |
| print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.") | |
| except wikipedia.exceptions.DisambiguationError as de: | |
| print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}") | |
| if de.options: | |
| option_title = de.options[0] | |
| try: | |
| temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True) | |
| final_page_title = temp_page.title | |
| if is_candidate_entity_focused and main_entities_from_query: # Check against original intent | |
| title_matches_main_entity = any(me.lower() in final_page_title.lower() for me in main_entities_from_query) | |
| if not title_matches_main_entity: | |
| print(f" ! Page title '{final_page_title}' (from disamb. of '{candidate_query}' -> '{option_title}') " | |
| f"does not strongly match main query entities: {main_entities_from_query}. Skipping.") | |
| continue | |
| if final_page_title in processed_page_titles: | |
| print(f" ~ Already processed '{final_page_title}'") | |
| continue | |
| page_object = temp_page | |
| print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'") | |
| except Exception as e_dis_opt: | |
| print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}") | |
| if page_object and final_page_title and (final_page_title not in processed_page_titles): | |
| # Extract main text | |
| main_text = page_object.content | |
| # Extract tables using BeautifulSoup | |
| try: | |
| html = page_object.html() | |
| soup = BeautifulSoup(html, 'html.parser') | |
| tables = soup.find_all('table') | |
| table_texts = [] | |
| for table in tables: | |
| rows = table.find_all('tr') | |
| table_lines = [] | |
| for row in rows: | |
| cells = row.find_all(['th', 'td']) | |
| cell_texts = [cell.get_text(strip=True) for cell in cells] | |
| if cell_texts: | |
| table_lines.append(" | ".join(cell_texts)) | |
| if table_lines: | |
| table_text = "\n".join(table_lines) | |
| table_texts.append(table_text) | |
| except Exception as e: | |
| print(f" !! Error extracting tables for '{final_page_title}': {e}") | |
| table_texts = [] | |
| # Combine main text and all table texts as separate chunks | |
| all_text_chunks = [main_text] + table_texts | |
| for chunk in all_text_chunks: | |
| found_pages_data.append((chunk, final_page_title)) | |
| processed_page_titles.add(final_page_title) | |
| print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}") | |
| except Exception as e: | |
| print(f" !! Unexpected error processing candidate '{candidate_query}': {e}") | |
| if not found_pages_data: print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.") | |
| else: print(f"\nFound {len(found_pages_data)} unique, validated page(s) for processing.") | |
| return found_pages_data | |
| def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]: | |
| core_query_parts = set() | |
| core_query_parts.add(query) | |
| if processed_query != query: core_query_parts.add(processed_query) | |
| if keywords: core_query_parts.add(" ".join(keywords[:2])) | |
| section_phrases_templates = [] | |
| lower_query_terms = set(query.lower().split()) | set(k.lower() for k in keywords) | |
| section_keywords_map = { | |
| "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"], | |
| "biography": ["biography", "life story", "career details", "background history"], | |
| "filmography": ["filmography", "list of films", "movie appearances", "acting roles"], | |
| } | |
| for section_term_key, specific_phrases_list in section_keywords_map.items(): | |
| # Check if the key (e.g., "discography") or any of its specific phrases (e.g. "list of studio albums") | |
| # are mentioned or implied by the query terms. | |
| if section_term_key in lower_query_terms or any(phrase_part in lower_query_terms for phrase_part in section_term_key.split()): | |
| section_phrases_templates.extend(specific_phrases_list) | |
| # Also check if phrases themselves are in query terms (e.g. query "list of albums by X") | |
| for phrase in specific_phrases_list: | |
| if phrase in query.lower(): # Check against original query for direct phrase matches | |
| section_phrases_templates.extend(specific_phrases_list) # Add all related if one specific is hit | |
| break | |
| section_phrases_templates = list(dict.fromkeys(section_phrases_templates)) # Deduplicate | |
| final_search_queries = set() | |
| if main_entities: | |
| entity_prefix = main_entities[0] | |
| final_search_queries.add(entity_prefix) | |
| for part in core_query_parts: | |
| final_search_queries.add(f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part) | |
| for phrase_template in section_phrases_templates: | |
| final_search_queries.add(f"{entity_prefix} {phrase_template}") | |
| if "list of" in phrase_template or "history of" in phrase_template : | |
| final_search_queries.add(f"{phrase_template} of {entity_prefix}") | |
| else: | |
| final_search_queries.update(core_query_parts) | |
| final_search_queries.update(section_phrases_templates) | |
| deduplicated_queries = list(dict.fromkeys(sq for sq in final_search_queries if sq and sq.strip())) | |
| print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}") | |
| all_results_docs: List[Document] = [] | |
| seen_content_hashes: Set[int] = set() | |
| k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier | |
| for search_query_variant in deduplicated_queries: | |
| try: | |
| results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch) | |
| print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.") | |
| for doc, score in results: # Assuming similarity_search_with_score returns (doc, score) | |
| content_hash = hash(doc.page_content[:250]) # Slightly more for hash uniqueness | |
| if content_hash not in seen_content_hashes: | |
| seen_content_hashes.add(content_hash) | |
| doc.metadata['retrieved_by_variant'] = search_query_variant | |
| doc.metadata['retrieval_score'] = float(score) # Store score | |
| all_results_docs.append(doc) | |
| except Exception as e: | |
| print(f" Error in semantic search for variant '{search_query_variant}': {e}") | |
| # Sort all collected unique results by score (FAISS L2 distance is lower is better) | |
| all_results_docs.sort(key=lambda x: x.metadata.get('retrieval_score', float('inf'))) | |
| print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.") | |
| return all_results_docs[:self.top_k_results] | |
| def _run(self, query: str = None, search_query: str = None, **kwargs) -> str: | |
| if not self._nlp or not self._embedding_model or not self._text_splitter: | |
| print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.") | |
| return "Error: Wikipedia tool components not initialized properly. Please check server logs." | |
| if not query: | |
| query = search_query or kwargs.get('q') or kwargs.get('search_term') | |
| try: | |
| print(f"\n--- Running {self.name} for query: '{query}' ---") | |
| main_entities, keywords, processed_query = self._extract_entities_and_keywords(query) | |
| print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'") | |
| fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query) | |
| if not fetched_pages_data: | |
| return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. " | |
| f"Main entities sought: {main_entities}") | |
| all_page_titles = [title for _, title in fetched_pages_data] | |
| print(f"\nSuccessfully fetched content for {len(fetched_pages_data)} Wikipedia page(s): {', '.join(all_page_titles)}") | |
| all_documents: List[Document] = [] | |
| for page_content, page_title in fetched_pages_data: | |
| chunks = self._text_splitter.split_text(page_content) | |
| if not chunks: | |
| print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.") | |
| continue | |
| for i, chunk_text in enumerate(chunks): | |
| all_documents.append(Document(page_content=chunk_text, metadata={ | |
| "source_page_title": page_title, | |
| "original_query": query, | |
| "chunk_index": i # Add chunk index for potential debugging or ordering | |
| })) | |
| print(f"Split content from '{page_title}' into {len(chunks)} chunks.") | |
| if not all_documents: | |
| return (f"Could not process content into searchable chunks from the fetched Wikipedia pages " | |
| f"({', '.join(all_page_titles)}) for query '{query}'.") | |
| print(f"\nTotal document chunks from all pages: {len(all_documents)}") | |
| print("Creating FAISS index from content of all fetched pages...") | |
| try: | |
| vector_store = FAISS.from_documents(all_documents, self._embedding_model) | |
| print("FAISS index created successfully.") | |
| except Exception as e: | |
| return f"Error creating FAISS vector store: {e}" | |
| print(f"\nPerforming enhanced semantic search across all collected content...") | |
| try: | |
| relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query) | |
| except Exception as e: | |
| return f"Error during semantic search: {e}" | |
| if not relevant_docs: | |
| return (f"No relevant information found within Wikipedia page(s) '{', '.join(list(dict.fromkeys(all_page_titles)))}' " | |
| f"for your query '{query}' using entity-focused semantic search with list retrieval.") | |
| unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs])) | |
| result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) " | |
| f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n") | |
| nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, " | |
| f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n") | |
| result_details = [] | |
| for i, doc in enumerate(relevant_docs): | |
| source_info = doc.metadata.get('source_page_title', 'Unknown Source') | |
| variant_info = doc.metadata.get('retrieved_by_variant', 'N/A') | |
| score_info = doc.metadata.get('retrieval_score', 'N/A') | |
| # Format the score only if it is numeric; 'N/A' would break the :.4f spec | |
| score_str = f"{score_info:.4f}" if isinstance(score_info, float) else str(score_info) | |
| detail = (f"Result {i+1} (source: '{source_info}', score: {score_str})\n" | |
| f"(Retrieved by: '{variant_info}')\n{doc.page_content}") | |
| result_details.append(detail) | |
| final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details) | |
| print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).") | |
| return final_result.strip() | |
| except Exception as e: | |
| import traceback | |
| print(f"Unexpected error in {self.name}: {traceback.format_exc()}") | |
| return f"An unexpected error occurred: {str(e)}" | |
| class EnhancedYoutubeScreenshotQA(BaseTool): | |
| name: str = "bird_species_screenshot_qa" | |
| description: str = ( | |
| "Use this tool to calculate the number of bird species on camera at any one time," | |
| "Input should be a dict with keys: 'youtube_url', 'question', and optional parameters. " | |
| "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'question': 'What animals are visible?'}" | |
| ) | |
| # Define Pydantic fields for the attributes we need to set | |
| device: Any = Field(default=None, exclude=True) | |
| processor_vqa: Any = Field(default=None, exclude=True) | |
| model_vqa: Any = Field(default=None, exclude=True) | |
| class Config: | |
| # Allow arbitrary types (needed for torch.device, model objects) | |
| arbitrary_types_allowed = True | |
| # Allow extra fields to be set | |
| extra = "allow" | |
| def __init__(self, **kwargs): | |
| super().__init__(**kwargs) | |
| # Initialize directories | |
| cache_dir = '/tmp/youtube_qa_cache/' | |
| video_dir = '/tmp/video/' | |
| frames_dir = '/tmp/video_frames/' | |
| # Initialize model and device | |
| self._initialize_model() | |
| # Create directories | |
| for dir_path in [cache_dir, video_dir, frames_dir]: | |
| os.makedirs(dir_path, exist_ok=True) | |
| def _get_config(self, key: str, default_value=None, input_data: Dict[str, Any] = None): | |
| """Get configuration value with fallback to defaults""" | |
| defaults = { | |
| 'frame_interval_seconds': 20, | |
| 'max_frames': 50, | |
| 'use_scene_detection': True, | |
| 'resize_frames': True, | |
| 'parallel_processing': True, | |
| 'cache_enabled': True, | |
| 'quality_threshold': 30.0, | |
| 'semantic_similarity_threshold': 0.8 | |
| } | |
| if input_data and key in input_data: | |
| return input_data[key] | |
| return defaults.get(key, default_value) | |
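| # Precedence example: _get_config('max_frames', input_data={'max_frames': 10}) -> 10 | |
| # (caller override), while _get_config('max_frames') -> 50 (built-in default). | |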
| def _initialize_model(self): | |
| """Initialize BLIP model for VQA with error handling""" | |
| try: | |
| self.device = torch.device("cpu") | |
| print(f"Using device: {self.device}") | |
| self.processor_vqa = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base") | |
| self.model_vqa = BlipForQuestionAnswering.from_pretrained( | |
| "Salesforce/blip-vqa-base" | |
| ).to(self.device) | |
| print("BLIP VQA model loaded successfully") | |
| except Exception as e: | |
| print(f"Error initializing VQA model: {str(e)}") | |
| raise | |
| def _get_video_hash(self, url: str) -> str: | |
| """Generate hash for video URL for caching""" | |
| return hashlib.md5(url.encode()).hexdigest() | |
| def _get_cache_path(self, video_hash: str, cache_type: str) -> str: | |
| """Get cache file path""" | |
| cache_dir = '/tmp/youtube_qa_cache/' | |
| return os.path.join(cache_dir, f"{video_hash}_{cache_type}") | |
| def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: | |
| """Load data from cache""" | |
| if not cache_enabled or not os.path.exists(cache_path): | |
| return None | |
| try: | |
| with open(cache_path, 'r') as f: | |
| return json.load(f) | |
| except Exception as e: | |
| print(f"Error loading cache: {str(e)}") | |
| return None | |
| def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): | |
| """Save data to cache""" | |
| if not cache_enabled: | |
| return | |
| try: | |
| with open(cache_path, 'w') as f: | |
| json.dump(data, f) | |
| except Exception as e: | |
| print(f"Error saving cache: {str(e)}") | |
| def download_youtube_video(self, url: str, video_hash: str, cache_enabled: bool = True) -> Optional[str]: | |
| """Enhanced YouTube video download with anti-bot measures""" | |
| video_dir = '/tmp/video/' | |
| output_filename = f'{video_hash}.mp4' | |
| output_path = os.path.join(video_dir, output_filename) | |
| # Check cache | |
| if cache_enabled and os.path.exists(output_path): | |
| print(f"Using cached video: {output_path}") | |
| return output_path | |
| # Clean directory | |
| self._clean_directory(video_dir) | |
| try: | |
| # Enhanced yt-dlp options with anti-bot measures | |
| ydl_opts = { | |
| # Format selection - prefer lower quality to avoid restrictions | |
| 'format': 'best[height<=480][ext=mp4]/best[height<=720][ext=mp4]/best[ext=mp4]/best', | |
| 'outtmpl': output_path, | |
| 'quiet': False, # Keep yt-dlp output visible for debugging | |
| 'no_warnings': False, | |
| 'merge_output_format': 'mp4', | |
| # Anti-bot headers and user agent | |
| 'http_headers': { | |
| 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36', | |
| 'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8', | |
| 'Accept-Language': 'en-us,en;q=0.5', | |
| 'Accept-Encoding': 'gzip,deflate', | |
| 'Accept-Charset': 'ISO-8859-1,utf-8;q=0.7,*;q=0.7', | |
| 'Connection': 'keep-alive', | |
| 'Upgrade-Insecure-Requests': '1', | |
| }, | |
| # Additional anti-detection measures | |
| 'extractor_args': { | |
| 'youtube': { | |
| 'skip': ['hls', 'dash'], # Skip some formats that might trigger detection | |
| 'player_skip': ['js'], # Skip JavaScript player | |
| } | |
| }, | |
| # Rate limiting | |
| 'sleep_interval': 1, | |
| 'max_sleep_interval': 5, | |
| 'sleep_interval_subtitles': 1, | |
| # Retry settings | |
| 'retries': 3, | |
| 'fragment_retries': 3, | |
| 'skip_unavailable_fragments': True, | |
| # Cookie handling (you can add browser cookies if needed) | |
| # 'cookiefile': '/path/to/cookies.txt', # Uncomment and set path if you have cookies | |
| # Additional options | |
| 'extract_flat': False, | |
| 'writesubtitles': False, | |
| 'writeautomaticsub': False, | |
| 'ignoreerrors': True, | |
| # Postprocessors | |
| 'postprocessors': [{ | |
| 'key': 'FFmpegVideoConvertor', | |
| 'preferedformat': 'mp4', | |
| }] | |
| } | |
| print(f"Attempting to download: {url}") | |
| # Try multiple download strategies | |
| strategies = [ | |
| # Strategy 1: Standard download | |
| ydl_opts, | |
| # Strategy 2: More conservative approach | |
| { | |
| **ydl_opts, | |
| 'format': 'worst[ext=mp4]/worst', # Try worst quality first | |
| 'sleep_interval': 2, | |
| 'max_sleep_interval': 10, | |
| }, | |
| # Strategy 3: Different user agent | |
| { | |
| **ydl_opts, | |
| 'http_headers': { | |
| 'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/16.1 Safari/605.1.15' | |
| }, | |
| 'format': 'best[height<=360][ext=mp4]/best[ext=mp4]/best', | |
| } | |
| ] | |
| last_error = None | |
| for i, strategy in enumerate(strategies, 1): | |
| try: | |
| print(f"Trying download strategy {i}/3...") | |
| with yt_dlp.YoutubeDL(strategy) as ydl: | |
| # Add some delay before download | |
| time.sleep(2) | |
| ydl.download([url]) | |
| if os.path.exists(output_path): | |
| print(f"Video downloaded successfully with strategy {i}: {output_path}") | |
| return output_path | |
| else: | |
| print(f"Strategy {i} completed but file not found") | |
| except Exception as e: | |
| last_error = e | |
| print(f"Strategy {i} failed: {str(e)}") | |
| if i < len(strategies): | |
| print(f"Trying next strategy...") | |
| # Add delay between strategies | |
| time.sleep(5) | |
| continue | |
| # If all strategies failed, try one more approach with cookies from browser | |
| print("All standard strategies failed. Trying with browser cookies...") | |
| try: | |
| cookie_strategy = { | |
| **ydl_opts, | |
| 'cookiesfrombrowser': ('chrome',), # Try to get cookies from Chrome | |
| 'format': 'worst[ext=mp4]/worst', | |
| } | |
| with yt_dlp.YoutubeDL(cookie_strategy) as ydl: | |
| ydl.download([url]) | |
| if os.path.exists(output_path): | |
| print(f"Video downloaded successfully with browser cookies: {output_path}") | |
| return output_path | |
| except Exception as e: | |
| print(f"Browser cookie strategy also failed: {str(e)}") | |
| print(f"All download strategies failed. Last error: {last_error}") | |
| return None | |
| except Exception as e: | |
| print(f"Error downloading YouTube video: {str(e)}") | |
| return None | |
| def _clean_directory(self, directory: str): | |
| """Clean directory contents""" | |
| if os.path.exists(directory): | |
| for filename in os.listdir(directory): | |
| file_path = os.path.join(directory, filename) | |
| try: | |
| if os.path.isfile(file_path) or os.path.islink(file_path): | |
| os.unlink(file_path) | |
| elif os.path.isdir(file_path): | |
| shutil.rmtree(file_path) | |
| except Exception as e: | |
| print(f'Failed to delete {file_path}. Reason: {e}') | |
| def _assess_frame_quality(self, frame: np.ndarray) -> float: | |
| """Assess frame quality using Laplacian variance (blur detection)""" | |
| try: | |
| gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | |
| return cv2.Laplacian(gray, cv2.CV_64F).var() | |
| except Exception: | |
| return 0.0 | |
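| # Rule of thumb (uncalibrated assumption): Laplacian variance below ~30 tends to | |
| # indicate a blurry frame, which is why 'quality_threshold' defaults to 30.0. | |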
| def _detect_scene_changes(self, video_path: str, threshold: float = 30.0) -> List[int]: | |
| """Detect scene changes in video""" | |
| scene_frames = [] | |
| try: | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| return [] | |
| prev_frame = None | |
| frame_count = 0 | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| if prev_frame is not None: | |
| # Calculate histogram difference | |
| hist1 = cv2.calcHist([prev_frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) | |
| hist2 = cv2.calcHist([frame], [0, 1, 2], None, [8, 8, 8], [0, 256, 0, 256, 0, 256]) | |
| diff = cv2.compareHist(hist1, hist2, cv2.HISTCMP_CHISQR) | |
| if diff > threshold: | |
| scene_frames.append(frame_count) | |
| prev_frame = frame.copy() | |
| frame_count += 1 | |
| cap.release() | |
| return scene_frames | |
| except Exception as e: | |
| print(f"Error in scene detection: {str(e)}") | |
| return [] | |
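| # Note: cv2.HISTCMP_CHISQR grows with dissimilarity, so a higher threshold means | |
| # fewer detected cuts; the default of 30.0 is a heuristic and may need per-video tuning. | |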
| def smart_extract_frames(self, video_path: str, video_hash: str, input_data: Dict[str, Any] = None) -> List[str]: | |
| """Intelligently extract frames with quality filtering and scene detection""" | |
| cache_enabled = self._get_config('cache_enabled', True, input_data) | |
| cache_path = self._get_cache_path(video_hash, "frames_info.json") | |
| cached_info = self._load_from_cache(cache_path, cache_enabled) | |
| if cached_info: | |
| # Verify cached frames still exist | |
| existing_frames = [f for f in cached_info['frame_paths'] if os.path.exists(f)] | |
| if len(existing_frames) == len(cached_info['frame_paths']): | |
| print(f"Using {len(existing_frames)} cached frames") | |
| return existing_frames | |
| # Clean frames directory | |
| frames_dir = '/tmp/video_frames/' | |
| self._clean_directory(frames_dir) | |
| try: | |
| cap = cv2.VideoCapture(video_path) | |
| if not cap.isOpened(): | |
| print("Error: Could not open video.") | |
| return [] | |
| fps = cap.get(cv2.CAP_PROP_FPS) | |
| total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| frame_interval_seconds = self._get_config('frame_interval_seconds', 20, input_data) | |
| frame_interval = max(1, int(fps * frame_interval_seconds)) | |
| print(f"Video info: {total_frames} frames, {fps:.2f} fps") | |
| # Get scene change frames if enabled | |
| scene_frames = set() | |
| use_scene_detection = self._get_config('use_scene_detection', True, input_data) | |
| if use_scene_detection: | |
| scene_frames = set(self._detect_scene_changes(video_path)) | |
| print(f"Detected {len(scene_frames)} scene changes") | |
| extracted_frames = [] | |
| frame_count = 0 | |
| saved_count = 0 | |
| max_frames = self._get_config('max_frames', 50, input_data) | |
| while True: | |
| ret, frame = cap.read() | |
| if not ret or saved_count >= max_frames: | |
| break | |
| # Check if we should extract this frame | |
| should_extract = ( | |
| frame_count % frame_interval == 0 or | |
| frame_count in scene_frames | |
| ) | |
| if should_extract: | |
| # Assess frame quality | |
| quality = self._assess_frame_quality(frame) | |
| quality_threshold = self._get_config('quality_threshold', 30.0, input_data) | |
| if quality >= quality_threshold: | |
| # Resize frame if enabled | |
| resize_frames = self._get_config('resize_frames', True, input_data) | |
| if resize_frames: | |
| height, width = frame.shape[:2] | |
| if width > 800: | |
| scale = 800 / width | |
| new_width = 800 | |
| new_height = int(height * scale) | |
| frame = cv2.resize(frame, (new_width, new_height)) | |
| frame_filename = os.path.join( | |
| frames_dir, | |
| f"frame_{frame_count:06d}_q{quality:.1f}.jpg" | |
| ) | |
| if cv2.imwrite(frame_filename, frame): | |
| extracted_frames.append(frame_filename) | |
| saved_count += 1 | |
| print(f"Extracted frame {saved_count}/{max_frames} " | |
| f"(quality: {quality:.1f})") | |
| frame_count += 1 | |
| cap.release() | |
| # Cache frame information | |
| frame_info = { | |
| 'frame_paths': extracted_frames, | |
| 'extraction_time': time.time(), | |
| 'total_frames_processed': frame_count, | |
| 'frames_extracted': len(extracted_frames) | |
| } | |
| self._save_to_cache(cache_path, frame_info, cache_enabled) | |
| print(f"Successfully extracted {len(extracted_frames)} high-quality frames") | |
| return extracted_frames | |
| except Exception as e: | |
| print(f"Exception during frame extraction: {e}") | |
| return [] | |
| def _answer_question_on_frame(self, frame_path: str, question: str) -> Tuple[str, float]: | |
| """Answer question on single frame with confidence scoring""" | |
| try: | |
| image = Image.open(frame_path).convert('RGB') | |
| inputs = self.processor_vqa(image, question, return_tensors="pt").to(self.device) | |
| with torch.no_grad(): | |
| outputs = self.model_vqa.generate(**inputs, output_scores=True, return_dict_in_generate=True) | |
| answer = self.processor_vqa.decode(outputs.sequences[0], skip_special_tokens=True) | |
| # Calculate confidence (simplified - you might want to use actual model confidence) | |
| confidence = 1.0 # Placeholder - BLIP doesn't directly provide confidence | |
| return answer, confidence | |
| except Exception as e: | |
| print(f"Error processing frame {frame_path}: {str(e)}") | |
| return "Error processing this frame", 0.0 | |
| def _process_frames_parallel(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> List[Tuple[str, str, float]]: | |
| """Process frames in parallel""" | |
| results = [] | |
| parallel_processing = self._get_config('parallel_processing', True, input_data) | |
| if parallel_processing: | |
| with ThreadPoolExecutor(max_workers=min(4, len(frame_files))) as executor: | |
| future_to_frame = { | |
| executor.submit(self._answer_question_on_frame, frame_path, question): frame_path | |
| for frame_path in frame_files | |
| } | |
| for future in as_completed(future_to_frame): | |
| frame_path = future_to_frame[future] | |
| try: | |
| answer, confidence = future.result() | |
| results.append((frame_path, answer, confidence)) | |
| print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") | |
| except Exception as e: | |
| print(f"Error processing {frame_path}: {str(e)}") | |
| results.append((frame_path, "Error", 0.0)) | |
| else: | |
| for frame_path in frame_files: | |
| answer, confidence = self._answer_question_on_frame(frame_path, question) | |
| results.append((frame_path, answer, confidence)) | |
| print(f"Processed {os.path.basename(frame_path)}: {answer} (conf: {confidence:.2f})") | |
| return results | |
| def _cluster_similar_answers(self, answers: List[str], input_data: Dict[str, Any] = None) -> Dict[str, List[str]]: | |
| """Cluster semantically similar answers""" | |
| if len(answers) <= 1: | |
| return {answers[0]: answers} if answers else {} | |
| try: | |
| # First try with standard TF-IDF settings | |
| vectorizer = TfidfVectorizer( | |
| stop_words='english', | |
| lowercase=True, | |
| min_df=1, # Include words that appear in at least 1 document | |
| max_df=1.0 # Include words that appear in up to 100% of documents | |
| ) | |
| tfidf_matrix = vectorizer.fit_transform(answers) | |
| # Check if we have any features after TF-IDF | |
| if tfidf_matrix.shape[1] == 0: | |
| raise ValueError("No features after TF-IDF processing") | |
| # Calculate cosine similarity | |
| similarity_matrix = cosine_similarity(tfidf_matrix) | |
| # Cluster similar answers | |
| clusters = defaultdict(list) | |
| used = set() | |
| semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) | |
| for i, answer in enumerate(answers): | |
| if i in used: | |
| continue | |
| cluster_key = answer | |
| clusters[cluster_key].append(answer) | |
| used.add(i) | |
| # Find similar answers | |
| for j in range(i + 1, len(answers)): | |
| if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: | |
| clusters[cluster_key].append(answers[j]) | |
| used.add(j) | |
| return dict(clusters) | |
| except Exception as e: | |
| print(f"Error in semantic clustering: {str(e)}") | |
| # Fallback 1: Try without stop words filtering | |
| try: | |
| print("Attempting clustering without stop word filtering...") | |
| vectorizer_no_stop = TfidfVectorizer( | |
| lowercase=True, | |
| min_df=1, | |
| token_pattern=r'\b\w+\b' # Match any word | |
| ) | |
| tfidf_matrix = vectorizer_no_stop.fit_transform(answers) | |
| if tfidf_matrix.shape[1] > 0: | |
| similarity_matrix = cosine_similarity(tfidf_matrix) | |
| clusters = defaultdict(list) | |
| used = set() | |
| semantic_similarity_threshold = self._get_config('semantic_similarity_threshold', 0.8, input_data) | |
| for i, answer in enumerate(answers): | |
| if i in used: | |
| continue | |
| cluster_key = answer | |
| clusters[cluster_key].append(answer) | |
| used.add(i) | |
| for j in range(i + 1, len(answers)): | |
| if j not in used and similarity_matrix[i][j] >= semantic_similarity_threshold: | |
| clusters[cluster_key].append(answers[j]) | |
| used.add(j) | |
| return dict(clusters) | |
| except Exception as e2: | |
| print(f"Fallback clustering also failed: {str(e2)}") | |
| # Fallback 2: Simple string-based clustering | |
| print("Using simple string-based clustering...") | |
| return self._simple_string_cluster(answers) | |
| def _simple_string_cluster(self, answers: List[str]) -> Dict[str, List[str]]: | |
| """Simple string-based clustering fallback""" | |
| clusters = defaultdict(list) | |
| # Normalize answers for comparison | |
| normalized_answers = {} | |
| for answer in answers: | |
| normalized = answer.lower().strip() | |
| normalized_answers[answer] = normalized | |
| # Track indices (not strings) in 'used' so duplicate answers keep their counts | |
| used = set() | |
| for i, answer in enumerate(answers): | |
| if i in used: | |
| continue | |
| cluster_key = answer | |
| clusters[cluster_key].append(answer) | |
| used.add(i) | |
| # Find similar answers using simple string similarity | |
| for j, other_answer in enumerate(answers[i+1:], i+1): | |
| if j in used: | |
| continue | |
| # Check for exact match after normalization | |
| if normalized_answers[answer] == normalized_answers[other_answer]: | |
| clusters[cluster_key].append(other_answer) | |
| used.add(j) | |
| # Alternatively, check if one string contains the other | |
| elif (normalized_answers[answer] in normalized_answers[other_answer] or | |
| normalized_answers[other_answer] in normalized_answers[answer]): | |
| clusters[cluster_key].append(other_answer) | |
| used.add(j) | |
| return dict(clusters) | |
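| # e.g. ['two birds', 'Two birds', 'two'] clusters to | |
| #   {'two birds': ['two birds', 'Two birds', 'two']} | |
| # since normalization matches the first two and substring matching pulls in 'two'. | |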
| def _analyze_temporal_patterns(self, results: List[Tuple[str, str, float]]) -> Dict[str, Any]: | |
| """Analyze temporal patterns in answers""" | |
| try: | |
| # Sort by frame number | |
| def get_frame_number(frame_path): | |
| match = re.search(r'frame_(\d+)', os.path.basename(frame_path)) | |
| return int(match.group(1)) if match else 0 | |
| sorted_results = sorted(results, key=lambda x: get_frame_number(x[0])) | |
| # Analyze answer changes over time | |
| answers_timeline = [result[1] for result in sorted_results] | |
| changes = [] | |
| for i in range(1, len(answers_timeline)): | |
| if answers_timeline[i] != answers_timeline[i-1]: | |
| changes.append({ | |
| 'frame_index': i, | |
| 'from_answer': answers_timeline[i-1], | |
| 'to_answer': answers_timeline[i] | |
| }) | |
| return { | |
| 'total_changes': len(changes), | |
| 'change_points': changes, | |
| 'stability_ratio': 1 - (len(changes) / max(1, len(answers_timeline) - 1)), | |
| 'answers_timeline': answers_timeline | |
| } | |
| except Exception as e: | |
| print(f"Error in temporal analysis: {str(e)}") | |
| return {'error': str(e)} | |
| def analyze_video_question(self, frame_files: List[str], question: str, input_data: Dict[str, Any] = None) -> Dict[str, Any]: | |
| """Comprehensive video question analysis""" | |
| if not frame_files: | |
| return { | |
| "final_answer": "No frames available for analysis.", | |
| "confidence": 0.0, | |
| "frame_count": 0, | |
| "error": "No valid frames found" | |
| } | |
| # Process all frames | |
| print(f"Processing {len(frame_files)} frames...") | |
| results = self._process_frames_parallel(frame_files, question, input_data) | |
| if not results: | |
| return { | |
| "final_answer": "Could not analyze any frames successfully.", | |
| "confidence": 0.0, | |
| "frame_count": 0, | |
| "error": "Frame processing failed" | |
| } | |
| # Extract answers and confidences | |
| answers = [result[1] for result in results if result[1] != "Error"] | |
| confidences = [result[2] for result in results if result[1] != "Error"] | |
| # Calculate statistical summary on numeric answers | |
| numeric_answers = [] | |
| for answer in answers: | |
| try: | |
| # Try to convert answer to float | |
| numeric_value = float(answer) | |
| numeric_answers.append(numeric_value) | |
| except (ValueError, TypeError): | |
| # Skip non-numeric answers | |
| pass | |
| if numeric_answers: | |
| stats = { | |
| "minimum": float(np.min(numeric_answers)), | |
| "maximum": float(np.max(numeric_answers)), | |
| "range": float(np.max(numeric_answers) - np.min(numeric_answers)), | |
| "mean": float(np.mean(numeric_answers)), | |
| "median": float(np.median(numeric_answers)), | |
| "count": len(numeric_answers), | |
| "data_type": "answers" | |
| } | |
| elif confidences: | |
| # Fallback to confidence statistics if no numeric answers | |
| stats = { | |
| "minimum": float(np.min(confidences)), | |
| "maximum": float(np.max(confidences)), | |
| "range": float(np.max(confidences) - np.min(confidences)), | |
| "mean": float(np.mean(confidences)), | |
| "median": float(np.median(confidences)), | |
| "count": len(confidences), | |
| "data_type": "confidences" | |
| } | |
| else: | |
| stats = { | |
| "minimum": 0.0, | |
| "maximum": 0.0, | |
| "range": 0.0, | |
| "mean": 0.0, | |
| "median": 0.0, | |
| "count": 0, | |
| "data_type": "none", | |
| "note": "No numeric results available for statistical summary" | |
| } | |
| if not answers: | |
| return { | |
| "final_answer": "All frame processing failed.", | |
| "confidence": 0.0, | |
| "frame_count": len(frame_files), | |
| "error": "No successful frame analysis" | |
| } | |
| # Cluster similar answers | |
| answer_clusters = self._cluster_similar_answers(answers, input_data) | |
| # Find most common cluster | |
| largest_cluster = max(answer_clusters.items(), key=lambda x: len(x[1])) | |
| most_common_answer = largest_cluster[0] | |
| # Calculate weighted confidence | |
| answer_counts = Counter(answers) | |
| total_answers = len(answers) | |
| frequency_confidence = answer_counts[most_common_answer] / total_answers | |
| avg_confidence = np.mean(confidences) if confidences else 0.0 | |
| final_confidence = (frequency_confidence * 0.7) + (avg_confidence * 0.3) | |
| # Temporal analysis | |
| temporal_analysis = self._analyze_temporal_patterns(results) | |
| return { | |
| "final_answer": most_common_answer, | |
| "confidence": final_confidence, | |
| "frame_count": len(frame_files), | |
| "successful_analyses": len(answers), | |
| "answer_distribution": dict(answer_counts), | |
| "semantic_clusters": {k: len(v) for k, v in answer_clusters.items()}, | |
| "temporal_analysis": temporal_analysis, | |
| "average_model_confidence": avg_confidence, | |
| "frequency_confidence": frequency_confidence, | |
| "statistical_summary": stats | |
| } | |
| def _run(self, youtube_url, question, **kwargs) -> str: | |
| """Enhanced main execution method""" | |
| question = "How many unique bird species are on camera?" | |
| input_data = { | |
| 'youtube_url': youtube_url, | |
| 'question': question | |
| } | |
| if not youtube_url or not question: | |
| return "Error: Input must include 'youtube_url' and 'question'." | |
| try: | |
| # Generate video hash for caching | |
| video_hash = self._get_video_hash(youtube_url) | |
| # Step 1: Download video | |
| print(f"Downloading YouTube video from {youtube_url}...") | |
| cache_enabled = self._get_config('cache_enabled', True, input_data) | |
| video_path = self.download_youtube_video(youtube_url, video_hash, cache_enabled) | |
| if not video_path or not os.path.exists(video_path): | |
| return "Error: Failed to download the YouTube video. This may be due to YouTube's anti-bot protection. Try using a different video or implement cookie authentication." | |
| # Step 2: Smart frame extraction | |
| print(f"Extracting frames with smart selection...") | |
| frame_files = self.smart_extract_frames(video_path, video_hash, input_data) | |
| if not frame_files: | |
| return "Error: Failed to extract frames from the video." | |
| # Step 3: Comprehensive analysis | |
| print(f"Analyzing {len(frame_files)} frames for question: '{question}'") | |
| analysis_result = self.analyze_video_question(frame_files, question, input_data) | |
| if analysis_result.get("error"): | |
| return f"Error: {analysis_result['error']}" | |
| # Format the statistical summary of the per-frame numeric answers | |
| result = f""" | |
| 📊 **STATISTICAL SUMMARY**: | |
| • Minimum: {analysis_result['statistical_summary']['minimum']:.2f} | |
| • Maximum: {analysis_result['statistical_summary']['maximum']:.2f} | |
| • Mean: {analysis_result['statistical_summary']['mean']:.2f} | |
| • Median: {analysis_result['statistical_summary']['median']:.2f} | |
| • Range: {analysis_result['statistical_summary']['range']:.2f} | |
| """.strip() | |
| return result | |
| except Exception as e: | |
| return f"Error during video analysis: {str(e)}" | |
| # Initialize the enhanced tool | |
| def create_enhanced_youtube_qa_tool(**kwargs): | |
| """Factory function to create the enhanced tool with custom parameters""" | |
| return EnhancedYoutubeScreenshotQA(**kwargs) | |
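| # Example (hypothetical URL): | |
| #   qa_tool = create_enhanced_youtube_qa_tool() | |
| #   print(qa_tool._run(youtube_url="https://youtube.com/watch?v=xyz", | |
| #                      question="How many unique bird species are on camera?")) | |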
| class YouTubeTranscriptExtractor(BaseTool): | |
| name: str = "youtube_transcript_extractor" | |
| description: str = ( | |
| "Downloads a YouTube video and extracts the complete audio transcript using speech recognition. " | |
| "Use this tool for questions about what people say in YouTube videos. " | |
| "Input should be a dict with keys: 'youtube_url' and optional parameters. " | |
| "Optional parameters include 'language' (e.g., 'en-US'), " | |
| "'cookies_file_path' (path to a cookies TXT file for authentication), " | |
| "or 'cookies_from_browser' (string specifying browser for cookies, e.g., 'chrome', 'firefox:profileName', 'edge+keyringName:profileName::containerName'). " | |
| "Example: {'youtube_url': 'https://youtube.com/watch?v=xyz', 'language': 'en-US'} or " | |
| "{'youtube_url': '...', 'cookies_file_path': '/path/to/cookies.txt'} or " | |
| "{'youtube_url': '...', 'cookies_from_browser': 'chrome'}" | |
| ) | |
| recognizer: Any = Field(default=None, exclude=True) | |
| class Config: | |
| arbitrary_types_allowed = True | |
| extra = Extra.allow # Allow extra attributes set in __init__ | |
| def __init__(self, **kwargs: Any): | |
| super().__init__(**kwargs) | |
| self.cache_dir = '/tmp/youtube_transcript_cache/' | |
| self.audio_dir = '/tmp/audio/' | |
| self.chunks_dir = '/tmp/audio_chunks/' | |
| self.recognizer = sr.Recognizer() | |
| self.recognizer.energy_threshold = 4000 | |
| self.recognizer.pause_threshold = 0.8 | |
| for dir_path in [self.cache_dir, self.audio_dir, self.chunks_dir]: | |
| os.makedirs(dir_path, exist_ok=True) | |
| def _get_config(self, key: str, default_value: Any = None, input_data: Optional[Dict[str, Any]] = None) -> Any: | |
| defaults = { | |
| 'language': 'en-US', | |
| 'chunk_length_ms': 30000, | |
| 'silence_thresh': -40, | |
| 'audio_quality': 'best', | |
| 'cache_enabled': True, | |
| 'min_silence_len': 500, | |
| 'overlap_ms': 1000, | |
| 'cookies_file_path': None, # New: Path to a cookies file | |
| 'cookies_from_browser': None # New: Browser string e.g., "chrome", "firefox:profile_name" | |
| } | |
| if input_data and key in input_data: | |
| return input_data[key] | |
| return defaults.get(key, default_value) | |
| def _get_video_hash(self, url: str) -> str: | |
| return hashlib.md5(url.encode()).hexdigest() | |
| def _get_cache_path(self, video_hash: str, cache_type: str) -> str: | |
| return os.path.join(self.cache_dir, f"{video_hash}_{cache_type}") | |
| def _load_from_cache(self, cache_path: str, cache_enabled: bool = True) -> Optional[Any]: | |
| if not cache_enabled or not os.path.exists(cache_path): | |
| return None | |
| try: | |
| with open(cache_path, 'r', encoding='utf-8') as f: | |
| return json.load(f) | |
| except Exception as e: | |
| print(f"Error loading cache: {str(e)}") | |
| return None | |
| def _save_to_cache(self, cache_path: str, data: Any, cache_enabled: bool = True): | |
| if not cache_enabled: | |
| return | |
| try: | |
| with open(cache_path, 'w', encoding='utf-8') as f: | |
| json.dump(data, f, ensure_ascii=False, indent=2) | |
| except Exception as e: | |
| print(f"Error saving cache: {str(e)}") | |
| def _clean_directory(self, directory: str): | |
| if os.path.exists(directory): | |
| for filename in os.listdir(directory): | |
| file_path = os.path.join(directory, filename) | |
| try: | |
| if os.path.isfile(file_path) or os.path.islink(file_path): | |
| os.unlink(file_path) | |
| elif os.path.isdir(file_path): | |
| shutil.rmtree(file_path) | |
| except Exception as e: | |
| print(f'Failed to delete {file_path}. Reason: {e}') | |
| def download_youtube_audio(self, url: str, video_hash: str, input_data: Optional[Dict[str, Any]] = None) -> Optional[str]: | |
| audio_quality = self._get_config('audio_quality', 'best', input_data) | |
| output_filename = f'{video_hash}.wav' | |
| output_path = os.path.join(self.audio_dir, output_filename) | |
| cache_enabled = self._get_config('cache_enabled', True, input_data) | |
| if cache_enabled and os.path.exists(output_path): | |
| print(f"Using cached audio: {output_path}") | |
| return output_path | |
| self._clean_directory(self.audio_dir) | |
| cookies_file_path = self._get_config('cookies_file_path', None, input_data) | |
| cookies_from_browser_str = self._get_config('cookies_from_browser', None, input_data) | |
| try: | |
| ydl_opts: Dict[str, Any] = { | |
| 'format': 'bestaudio[ext=m4a]/bestaudio/best', | |
| 'outtmpl': os.path.join(self.audio_dir, f'{video_hash}.%(ext)s'), | |
| 'quiet': False, | |
| 'no_warnings': False, | |
| 'extract_flat': False, # Ensure this is false for actual downloads | |
| 'writethumbnail': False, | |
| 'writeinfojson': False, | |
| 'postprocessors': [{ | |
| 'key': 'FFmpegExtractAudio', | |
| 'preferredcodec': 'wav', | |
| 'preferredquality': '192' if audio_quality == 'best' else '128', | |
| }], | |
| 'http_headers': { | |
| 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36' | |
| }, | |
| 'nocheckcertificate': True, | |
| } | |
| if cookies_file_path: | |
| ydl_opts['cookiefile'] = cookies_file_path | |
| print(f"Using cookies from file: {cookies_file_path}") | |
| elif cookies_from_browser_str: | |
| parsed_browser, parsed_profile, parsed_keyring, parsed_container = None, None, None, None | |
| temp_str = cookies_from_browser_str | |
| if '::' in temp_str: | |
| main_part_before_container, parsed_container_val = temp_str.split('::', 1) | |
| parsed_container = parsed_container_val if parsed_container_val else None | |
| temp_str = main_part_before_container | |
| if ':' in temp_str: | |
| browser_keyring_part, parsed_profile_val = temp_str.split(':', 1) | |
| parsed_profile = parsed_profile_val if parsed_profile_val else None | |
| temp_str = browser_keyring_part | |
| if '+' in temp_str: | |
| parsed_browser_val, parsed_keyring_val = temp_str.split('+', 1) | |
| parsed_browser = parsed_browser_val | |
| parsed_keyring = parsed_keyring_val if parsed_keyring_val else None | |
| else: | |
| parsed_browser = temp_str | |
| if parsed_browser: | |
| # yt-dlp expects cookiesfrombrowser as a tuple: (BROWSER, PROFILE, KEYRING, CONTAINER) | |
| final_tuple: Tuple[Optional[str], ...] = ( | |
| parsed_browser, | |
| parsed_profile, | |
| parsed_keyring, | |
| parsed_container | |
| ) | |
| ydl_opts['cookiesfrombrowser'] = final_tuple | |
| print(f"Attempting to use cookies from browser spec '{cookies_from_browser_str}', parsed as: {final_tuple}") | |
| else: | |
| print(f"Invalid or empty browser name in cookies_from_browser string: '{cookies_from_browser_str}'") | |
| with yt_dlp.YoutubeDL(ydl_opts) as ydl: | |
| print(f"Downloading audio from: {url} with options: {ydl_opts}") | |
| ydl.download([url]) | |
| if os.path.exists(output_path): | |
| print(f"Audio downloaded successfully: {output_path}") | |
| return output_path | |
| else: | |
| possible_files = glob.glob(os.path.join(self.audio_dir, f'{video_hash}.*')) | |
| if possible_files: | |
| source_file = possible_files[0] | |
| if not source_file.endswith('.wav'): | |
| try: | |
| audio = AudioSegment.from_file(source_file) | |
| audio.export(output_path, format="wav") | |
| os.remove(source_file) | |
| print(f"Audio converted to WAV: {output_path}") | |
| return output_path | |
| except Exception as e: | |
| print(f"Error converting audio: {str(e)}") | |
| return None | |
| else: # Already a .wav, possibly due to postprocessor already creating it with a different ext pattern | |
| if source_file != output_path: # if names differ due to original extension | |
| shutil.move(source_file, output_path) | |
| print(f"Audio file found: {output_path}") | |
| return output_path | |
| print(f"No audio file found at expected path after download: {output_path}") | |
| return None | |
| except yt_dlp.utils.DownloadError as de: | |
| print(f"yt-dlp DownloadError: {str(de)}") | |
| if "Sign in to confirm you're not a bot" in str(de) and not (cookies_file_path or cookies_from_browser_str): | |
| print("Authentication required. Consider using 'cookies_file_path' or 'cookies_from_browser' options.") | |
| return None | |
| except Exception as e: | |
| print(f"Error downloading YouTube audio: {type(e).__name__} - {str(e)}") | |
| # Fallback attempt is removed as it's unlikely to succeed if the primary authenticated attempt fails due to bot detection | |
| return None | |
| def _split_audio_intelligent(self, audio_path: str, input_data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: | |
| self._clean_directory(self.chunks_dir) | |
| try: | |
| audio = AudioSegment.from_wav(audio_path) | |
| chunk_length_ms = self._get_config('chunk_length_ms', 30000, input_data) | |
| silence_thresh = self._get_config('silence_thresh', -40, input_data) | |
| min_silence_len = self._get_config('min_silence_len', 500, input_data) | |
| overlap_ms = self._get_config('overlap_ms', 1000, input_data) # Used only by the time-based fallback below | |
| chunks = split_on_silence( | |
| audio, | |
| min_silence_len=min_silence_len, | |
| silence_thresh=silence_thresh, | |
| keep_silence=True | |
| ) | |
| processed_chunks: List[AudioSegment] = [] # type: ignore | |
| # Combine small chunks or re-chunk if silence splitting is ineffective | |
| temp_chunk: Optional[AudioSegment] = None # type: ignore | |
| for idx, chunk in enumerate(chunks): | |
| if temp_chunk is None: | |
| temp_chunk = chunk | |
| else: | |
| temp_chunk += chunk | |
| if len(temp_chunk) > chunk_length_ms / 2 or idx == len(chunks) - 1: # Arbitrary threshold to combine small chunks; always flush on the last chunk | |
| processed_chunks.append(temp_chunk) | |
| temp_chunk = None | |
| if not processed_chunks or any(len(p_chunk) > chunk_length_ms * 1.5 for p_chunk in processed_chunks): # If still problematic | |
| print("Using time-based splitting due to ineffective silence splitting or overly large chunks...") | |
| processed_chunks = [] | |
| for i in range(0, len(audio), chunk_length_ms - overlap_ms): | |
| chunk_segment = audio[i:i + chunk_length_ms] | |
| if len(chunk_segment) > 1000: | |
| processed_chunks.append(chunk_segment) | |
| chunk_data = [] | |
| current_time_ms = 0 | |
| for i, chunk_segment in enumerate(processed_chunks): | |
| if len(chunk_segment) < 1000: continue | |
| chunk_filename = os.path.join(self.chunks_dir, f"chunk_{i:04d}.wav") | |
| chunk_segment.export(chunk_filename, format="wav") | |
| duration_s = len(chunk_segment) / 1000.0 | |
| start_time_s = current_time_ms / 1000.0 | |
| end_time_s = start_time_s + duration_s | |
| chunk_data.append({ | |
| 'filename': chunk_filename, 'index': i, | |
| 'start_time': start_time_s, 'duration': duration_s, 'end_time': end_time_s | |
| }) | |
| current_time_ms += len(chunk_segment) # Approximation, true timestamping is harder | |
| print(f"Split audio into {len(chunk_data)} chunks") | |
| return chunk_data | |
| except Exception as e: | |
| print(f"Error splitting audio: {str(e)}") | |
| try: # Fallback: single chunk | |
| audio = AudioSegment.from_wav(audio_path) | |
| duration = len(audio) / 1000.0 | |
| return [{'filename': audio_path, 'index': 0, 'start_time': 0, 'duration': duration, 'end_time': duration}] | |
| except Exception: return [] | |
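| # Each returned chunk entry looks like (values illustrative): | |
| #   {'filename': '/tmp/audio_chunks/chunk_0000.wav', 'index': 0, | |
| #    'start_time': 0.0, 'duration': 28.4, 'end_time': 28.4} | |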
| def _transcribe_audio_chunk(self, chunk_info: Dict[str, Any], input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: | |
| chunk_path = chunk_info['filename'] | |
| base_result = { | |
| 'start_time': chunk_info.get('start_time', 0), 'end_time': chunk_info.get('end_time', 0), | |
| 'duration': chunk_info.get('duration', 0), 'index': chunk_info.get('index', -1), | |
| 'success': False, 'confidence': 0.0 | |
| } | |
| try: | |
| language = self._get_config('language', 'en-US', input_data) | |
| with sr.AudioFile(chunk_path) as source: | |
| self.recognizer.adjust_for_ambient_noise(source, duration=0.2) # Shorter adjustment | |
| audio_data = self.recognizer.record(source) | |
| try: | |
| text = self.recognizer.recognize_google(audio_data, language=language) | |
| return {**base_result, 'text': text, 'confidence': 1.0, 'success': True} | |
| except sr.UnknownValueError: | |
| try: # Try without specific language | |
| text = self.recognizer.recognize_google(audio_data) | |
| return {**base_result, 'text': text, 'confidence': 0.8, 'success': True} # Lower confidence | |
| except sr.UnknownValueError: | |
| return {**base_result, 'text': '[INAUDIBLE]'} | |
| except sr.RequestError as e: | |
| return {**base_result, 'text': f'[RECOGNITION_ERROR: {str(e)}]', 'error': str(e)} | |
| except Exception as e: | |
| return {**base_result, 'text': f'[ERROR: {str(e)}]', 'error': str(e)} | |
| def _transcribe_chunks_parallel(self, chunk_data: List[Dict[str, Any]], input_data: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: | |
| results = [] | |
| max_workers = min(os.cpu_count() or 1, 4) # Limit workers | |
| with ThreadPoolExecutor(max_workers=max_workers) as executor: | |
| future_to_chunk = { | |
| executor.submit(self._transcribe_audio_chunk, chunk_info, input_data): chunk_info | |
| for chunk_info in chunk_data | |
| } | |
| for future in as_completed(future_to_chunk): | |
| chunk_info = future_to_chunk[future] | |
| try: | |
| result = future.result() | |
| results.append(result) | |
| status = "Transcribed" if result['success'] else "Failed" | |
| preview = result['text'][:50] + "..." if len(result['text']) > 50 else result['text'] | |
| print(f"{status} chunk {result['index']}: {preview}") | |
| except Exception as e: | |
| print(f"Error processing chunk {chunk_info.get('index', '?')}: {str(e)}") | |
| results.append({ | |
| 'text': f'[PROCESSING_ERROR: {str(e)}]', 'confidence': 0.0, | |
| 'start_time': chunk_info.get('start_time', 0), 'end_time': chunk_info.get('end_time', 0), | |
| 'duration': chunk_info.get('duration', 0), 'index': chunk_info.get('index', 0), | |
| 'success': False, 'error': str(e) | |
| }) | |
| results.sort(key=lambda x: x['index']) | |
| return results | |
| def extract_transcript(self, audio_path: str, video_hash: str, input_data: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: | |
| cache_enabled = self._get_config('cache_enabled', True, input_data) | |
| cache_path = self._get_cache_path(video_hash, "transcript.json") | |
| cached_transcript = self._load_from_cache(cache_path, cache_enabled) | |
| if cached_transcript: | |
| print("Using cached transcript") | |
| return cached_transcript | |
| try: | |
| print("Splitting audio into chunks...") | |
| chunk_data = self._split_audio_intelligent(audio_path, input_data) | |
| if not chunk_data: | |
| return {'error': 'Failed to split audio', 'full_transcript': '', 'success_rate': 0.0} | |
| print(f"Transcribing {len(chunk_data)} audio chunks...") | |
| transcript_results = self._transcribe_chunks_parallel(chunk_data, input_data) | |
| successful_chunks = [r for r in transcript_results if r['success']] | |
| full_text = ' '.join([r['text'] for r in successful_chunks if r['text'] and '[INAUDIBLE]' not in r['text'] and 'ERROR' not in r['text']]).strip() | |
| total_c = len(transcript_results) | |
| successful_c = len(successful_chunks) | |
| success_rate = successful_c / total_c if total_c > 0 else 0.0 | |
| final_result = { | |
| 'full_transcript': full_text, 'word_count': len(full_text.split()), | |
| 'total_chunks': total_c, 'successful_chunks': successful_c, 'success_rate': success_rate, | |
| 'extraction_timestamp': time.time(), 'extraction_date': time.strftime('%Y-%m-%d %H:%M:%S'), | |
| 'detailed_results': transcript_results | |
| } | |
| self._save_to_cache(cache_path, final_result, cache_enabled) | |
| print(f"Transcript extraction completed. Success rate: {success_rate:.1%}") | |
| return final_result | |
| except Exception as e: | |
| print(f"Error during transcript extraction: {str(e)}") | |
| return {'error': str(e), 'full_transcript': '', 'success_rate': 0.0} | |
| def _run(self, youtube_url: str, **kwargs: Any) -> str: | |
| input_data = {'youtube_url': youtube_url, **kwargs} | |
| if not youtube_url: return "Error: youtube_url is required." | |
| try: | |
| video_hash = self._get_video_hash(youtube_url) | |
| print(f"Processing YouTube URL: {youtube_url} (Hash: {video_hash})") | |
| audio_path = self.download_youtube_audio(youtube_url, video_hash, input_data) | |
| if not audio_path or not os.path.exists(audio_path): | |
| return "Error: Failed to download YouTube audio. Check URL or authentication (cookies)." | |
| print("Extracting audio transcript...") | |
| transcript_result = self.extract_transcript(audio_path, video_hash, input_data) | |
| if transcript_result.get("error"): return f"Error: {transcript_result['error']}" | |
| main_transcript = transcript_result.get('full_transcript', '') | |
| if not main_transcript: return "Error: No transcript could be extracted." | |
| print(f"Transcript extracted. Word count: {transcript_result.get('word_count',0)}. Success: {transcript_result.get('success_rate',0):.1%}") | |
| return "TRANSCRIPT: " + main_transcript | |
| except Exception as e: | |
| print(f"Unhandled error in _run: {str(e)}") # For debugging | |
| return f"Error during transcript extraction: {str(e)}" | |
| # Factory function to create the tool | |
| def create_youtube_transcript_tool(**kwargs): | |
| """Factory function to create the transcript extraction tool with custom parameters""" | |
| return YouTubeTranscriptExtractor(**kwargs) | |
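| # Example (hypothetical URL): | |
| #   transcript_tool = create_youtube_transcript_tool() | |
| #   print(transcript_tool._run("https://youtube.com/watch?v=xyz", language="en-US")) | |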
| # --- Model Configuration --- | |
| def create_llm_pipeline(): | |
| #model_id = "meta-llama/Llama-2-13b-chat-hf" | |
| #model_id = "meta-llama/Llama-3.3-70B-Instruct" | |
| #model_id = "mistralai/Mistral-Small-24B-Base-2501" | |
| model_id = "mistralai/Mistral-7B-Instruct-v0.3" | |
| #model_id = "Meta-Llama/Llama-2-7b-chat-hf" | |
| #model_id = "NousResearch/Nous-Hermes-2-Mistral-7B-DPO" | |
| #model_id = "TheBloke/Mistral-7B-Instruct-v0.1-GGUF" | |
| #model_id = "mistralai/Mistral-7B-Instruct-v0.2" | |
| #model_id = "Qwen/Qwen2-7B-Instruct" | |
| #model_id = "GSAI-ML/LLaDA-8B-Instruct" | |
| return pipeline( | |
| "text-generation", | |
| model=model_id, | |
| device_map="cpu", | |
| torch_dtype=torch.float16, # NOTE: half precision can be slow or unsupported on CPU; use float32 if generation fails | |
| max_new_tokens=1024, | |
| temperature=0.3, | |
| top_k=50, | |
| top_p=0.95 | |
| ) | |
| nlp = None # Set to None if not using spaCy, so the regex fallback is used in extract_entities | |
| # --- Agent State Definition --- | |
| class AgentState(TypedDict): | |
| messages: Annotated[List[AnyMessage], lambda x, y: x + y] | |
| done: bool # TypedDict fields cannot declare defaults; treat a missing value as False | |
| question: str | |
| task_id: str | |
| input_file: Optional[bytes] | |
| file_type: Optional[str] | |
| context: List[Document] # Using LangChain's Document class | |
| file_path: Optional[str] | |
| youtube_url: Optional[str] | |
| answer: Optional[str] | |
| frame_answers: Optional[list] | |
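| # Minimal initial-state sketch (field values are illustrative; the remaining | |
| # keys are filled in as the graph runs): | |
| #   initial_state: AgentState = { | |
| #       "messages": [HumanMessage(content=question)], | |
| #       "question": question, "task_id": "task-001", | |
| #       "input_file": None, "file_type": None, "context": [], | |
| #       "file_path": None, "youtube_url": None, "answer": None, | |
| #       "frame_answers": None, "done": False, | |
| #   } | |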
| # --- Define Call LLM function --- | |
| # 3. Improved LLM call with memory management | |
| def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState: | |
| """Enhanced LLM call with better prompt engineering and hallucination prevention.""" | |
| print("Running call_llm with memory management...") | |
| original_messages = messages_for_llm = state["messages"] | |
| # Context management - be more aggressive about truncation | |
| system_message_content = None | |
| if messages_for_llm and isinstance(messages_for_llm[0], SystemMessage): | |
| system_message_content = messages_for_llm[0] | |
| regular_messages = messages_for_llm[1:] | |
| else: | |
| regular_messages = messages_for_llm | |
| # Keep only the most recent messages (more aggressive) | |
| max_regular_messages = 6 # Reduced from 9 | |
| if len(regular_messages) > max_regular_messages: | |
| print(f"🔄 Truncating to {max_regular_messages} recent messages") | |
| regular_messages = regular_messages[-max_regular_messages:] | |
| # Reconstruct for LLM | |
| messages_for_llm = [] | |
| if system_message_content: | |
| messages_for_llm.append(system_message_content) | |
| messages_for_llm.extend(regular_messages) | |
| # Character limit check | |
| total_chars = sum(len(str(msg.content)) for msg in messages_for_llm) | |
| char_limit = 20000 | |
| if total_chars > char_limit: | |
| print(f"📏 Context too long ({total_chars} chars) - further truncation") | |
| while regular_messages and sum(len(str(m.content)) for m in regular_messages) > char_limit - (len(str(system_message_content.content)) if system_message_content else 0): | |
| regular_messages.pop(0) | |
| messages_for_llm = [] | |
| if system_message_content: | |
| messages_for_llm.append(system_message_content) | |
| messages_for_llm.extend(regular_messages) | |
| new_state = state.copy() | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"🤖 Calling LLM with {len(messages_for_llm)} messages") | |
| # Convert to simple string format that the model can understand | |
| if len(messages_for_llm) == 2 and isinstance(messages_for_llm[0], SystemMessage) and isinstance(messages_for_llm[1], HumanMessage): | |
| # Initial query - use simple format | |
| system_content = messages_for_llm[0].content | |
| human_content = messages_for_llm[1].content | |
| formatted_input = f"{system_content}\n\nHuman: {human_content}\n\nAssistant:" | |
| else: | |
| # Ongoing conversation - build context | |
| formatted_input = "" | |
| # Add system message if present | |
| if system_message_content: | |
| formatted_input += f"{system_message_content.content}\n\n" | |
| # Add conversation messages | |
| for msg in regular_messages: | |
| if isinstance(msg, HumanMessage): | |
| formatted_input += f"Human: {msg.content}\n\n" | |
| elif isinstance(msg, AIMessage): | |
| formatted_input += f"Assistant: {msg.content}\n\n" | |
| elif isinstance(msg, ToolMessage): | |
| formatted_input += f"Tool Result: {msg.content}\n\n" | |
| # Add explicit instruction for immediate final answer if we have recent tool results | |
| if any(isinstance(msg, ToolMessage) for msg in regular_messages[-2:]): | |
| formatted_input += "Based on the tool results above, provide your FINAL ANSWER now.\n\n" | |
| formatted_input += "REMINDER ON ANSWER FORMAT: \n" | |
| formatted_input += "- Numbers: no commas, no units unless specified\n" | |
| formatted_input += "- Strings: no articles, no abbreviations, digits in plain text\n" | |
| formatted_input += "- Lists: comma-separated following above rules\n" | |
| formatted_input += "- Be extremely brief and concise" | |
| formatted_input += "Assistant:" | |
| print(f"Input preview: {formatted_input[:300]}...") | |
| llm_response_object = llm_model.invoke(formatted_input) | |
| # Process response and clean up hallucinated content | |
| if isinstance(llm_response_object, BaseMessage): | |
| raw_content = llm_response_object.content | |
| elif hasattr(llm_response_object, 'content'): | |
| raw_content = str(llm_response_object.content) | |
| else: | |
| raw_content = str(llm_response_object) | |
| # Clean up the response to prevent hallucinated follow-up questions | |
| cleaned_content = clean_llm_response(raw_content) | |
| ai_message_response = AIMessage(content=cleaned_content) | |
| print(f"🔍 LLM Response preview: {cleaned_content[:200]}...") | |
| final_messages = original_messages + [ai_message_response] | |
| new_state["messages"] = final_messages | |
| new_state.pop("done", None) | |
| except Exception as e: | |
| print(f"❌ LLM call failed: {e}") | |
| error_message = AIMessage(content=f"Error: LLM call failed - {str(e)}") | |
| new_state["messages"] = original_messages + [error_message] | |
| new_state["done"] = True | |
| finally: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| return new_state | |
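| # When wiring the graph, the LLM can be bound with functools.partial (imported | |
| # above); node name is illustrative: | |
| #   builder.add_node("llm", partial(call_llm_with_memory_management, llm_model=llm)) | |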
| def clean_llm_response(response_text: str) -> str: | |
| """ | |
| Clean LLM response to prevent hallucinated follow-up questions and conversations. | |
| Specifically handles ReAct format: Thought: -> Action: -> Action Input: | |
| """ | |
| if not response_text: | |
| return response_text | |
| print(f"Initial response: {response_text[:200]}...") | |
| # Isolate the text generated by the assistant in the last turn. | |
| # This prevents parsing examples or instructions from the preamble. | |
| assistant_marker = "Assistant:" | |
| last_marker_idx = response_text.rfind(assistant_marker) | |
| text_to_process = response_text # Default to full text if marker not found | |
| if last_marker_idx != -1: | |
| # If "Assistant:" is found, process only the text after the last occurrence. | |
| text_to_process = response_text[last_marker_idx + len(assistant_marker):].strip() | |
| print(f"ℹ️ Parsing content after last 'Assistant:': {text_to_process[:200]}...") | |
| else: | |
| # If "Assistant:" is not found, process the whole input. | |
| # This might occur if the input is already just the assistant's direct response | |
| # or if the prompt structure is different. | |
| print(f"ℹ️ No 'Assistant:' marker found. Processing entire input as is.") | |
| # Now, all subsequent operations use 'text_to_process' | |
| # Try to find a complete ReAct pattern in the assistant's actual output | |
| react_pattern = r'Thought:\s*(.*?)\s*Action:\s*([^\n\r]+)\s*Action Input:\s*(.*?)(?=\s*(?:Thought:|Action:|FINAL ANSWER:|$))' | |
| # Apply search to 'text_to_process' | |
| react_match = re.search(react_pattern, text_to_process, re.DOTALL | re.IGNORECASE) | |
| if react_match: | |
| thought_text = react_match.group(1).strip() | |
| action_name = react_match.group(2).strip() | |
| action_input = react_match.group(3).strip() | |
| # Clean up the action input - remove any trailing content that looks like instructions | |
| action_input_clean = re.sub(r'\s*(When you have|FINAL ANSWER|ANSWER FORMAT|IMPORTANT:).*$', '', action_input, flags=re.DOTALL | re.IGNORECASE) | |
| action_input_clean = action_input_clean.strip() | |
| react_sequence = f"Thought: {thought_text}\nAction: {action_name}\nAction Input: {action_input_clean}" | |
| print(f"🔧 Found ReAct pattern - Action: {action_name}, Input: {action_input_clean[:100]}...") | |
| # Check if there's a FINAL ANSWER after the action input (this would be hallucination) | |
| # Check in the remaining part of 'text_to_process' | |
| remaining_text_in_process = text_to_process[react_match.end():] | |
| final_answer_after = re.search(r'FINAL ANSWER:', remaining_text_in_process, re.IGNORECASE) | |
| if final_answer_after: | |
| print(f"🚫 Removed hallucinated FINAL ANSWER after tool call") | |
| return react_sequence | |
| # If no ReAct pattern in 'text_to_process', check for standalone FINAL ANSWER | |
| # This variable will hold the text being processed for FINAL ANSWER and then for fallback. | |
| current_text_for_processing = text_to_process | |
| final_answer_match = re.search(r"FINAL ANSWER:\s*(.+?)(?=\n|$)", current_text_for_processing, re.IGNORECASE) | |
| if final_answer_match: | |
| answer_content = final_answer_match.group(1).strip() | |
| template_phrases = [ | |
| '[concise answer only]', | |
| '[concise answer - number/word/list only]', | |
| '[brief answer]', | |
| '[your answer here]', | |
| 'concise answer only', | |
| 'brief answer', | |
| 'your answer here' | |
| ] | |
| if any(phrase.lower() in answer_content.lower() for phrase in template_phrases): | |
| print(f"🚫 Ignoring template FINAL ANSWER: {answer_content}") | |
| # Remove the template FINAL ANSWER and continue cleaning on the remainder of 'current_text_for_processing' | |
| current_text_for_processing = current_text_for_processing[:final_answer_match.start()].strip() | |
| # Fall through to the general cleanup section below | |
| else: | |
| # Keep everything from the start of 'current_text_for_processing' up to and including the real FINAL ANSWER line only | |
| cleaned_output = current_text_for_processing[:final_answer_match.end()] | |
| # Check if there's additional content after FINAL ANSWER in 'current_text_for_processing' | |
| remaining_after_final_answer = current_text_for_processing[final_answer_match.end():].strip() | |
| if remaining_after_final_answer: | |
| print(f"🚫 Removed content after FINAL ANSWER: {remaining_after_final_answer[:100]}...") | |
| return cleaned_output.strip() | |
| # If no ReAct, or FINAL ANSWER was a template or not found, apply fallback cleaning to 'current_text_for_processing' | |
| lines = current_text_for_processing.split('\n') | |
| cleaned_lines = [] | |
| for i, line in enumerate(lines): | |
| # Stop if we see the model repeating system instructions | |
| if re.search(r'\[SYSTEM\]|\[HUMAN\]|\[ASSISTANT\]|\[TOOL\]', line, re.IGNORECASE): | |
| print(f"🚫 Stopped at repeated system format: {line}") | |
| break | |
| # Stop if we see the model generating format instructions | |
| if re.search(r'CRITICAL INSTRUCTIONS|FORMAT for tool use|ANSWER FORMAT', line, re.IGNORECASE): | |
| print(f"🚫 Stopped at repeated instructions: {line}") | |
| break | |
| # Stop if we see the model role-playing as a human asking questions | |
| if re.search(r'(what are|what is|how many|can you tell me)', line, re.IGNORECASE) and not line.strip().startswith(('Thought:', 'Action:', 'Action Input:')): | |
| # Make sure this isn't part of a legitimate thought process | |
| if i > 0 and not any(keyword in lines[i-1] for keyword in ['Thought:', 'Action:', 'need to']): | |
| print(f"🚫 Stopped at hallucinated question: {line}") | |
| break | |
| cleaned_lines.append(line) | |
| cleaned = '\n'.join(cleaned_lines).strip() | |
| print(f"Final cleaned response (fallback): {cleaned[:200]}...") | |
| return cleaned | |
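| # Hedged worked example (hypothetical model output, not a recorded run): | |
| # clean_llm_response("Assistant: Thought: I need the population.\nAction: enhanced_search\n" | |
| #                    "Action Input: population of Oslo\nFINAL ANSWER: 709037") | |
| # returns "Thought: I need the population.\nAction: enhanced_search\nAction Input: population of Oslo", | |
| # dropping the FINAL ANSWER hallucinated after the tool call. | |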
| def parse_react_output(state: AgentState) -> AgentState: | |
| """ | |
| Enhanced parsing with better FINAL ANSWER detection and flow control. | |
| """ | |
| print("Running parse_react_output...") | |
| messages = state.get("messages", []) | |
| if not messages: | |
| print("No messages in state.") | |
| new_state = state.copy() | |
| new_state["done"] = True | |
| return new_state | |
| # DEBUG | |
| print(f"parse_react_output: Entry message count: {len(messages)}") | |
| if messages and hasattr(messages[-1], 'tool_calls'): | |
| print(f"parse_react_output: Number of tool calls in last AIMessage: {len(messages[-1].tool_calls)}") | |
| last_message = messages[-1] | |
| new_state = state.copy() | |
| if not isinstance(last_message, AIMessage): | |
| print("Last message is not an AIMessage instance.") | |
| return new_state | |
| content = last_message.content | |
| if not isinstance(content, str): | |
| content = str(content) | |
| # Look for FINAL ANSWER first - this should take absolute priority | |
| # Use a more precise regex to capture just the answer line | |
| final_answer_match = re.search(r"FINAL ANSWER:\s*([^\n\r]+)", content, re.IGNORECASE) | |
| if final_answer_match: | |
| final_answer_text = final_answer_match.group(1).strip() | |
| # Check if this is template text (not a real answer) | |
| template_phrases = [ | |
| '[concise answer only]', | |
| '[concise answer - number/word/list only]', | |
| '[brief answer]', | |
| '[your answer here]', | |
| 'concise answer only', | |
| 'brief answer', | |
| 'your answer here' | |
| ] | |
| # If it's template text, don't treat it as a final answer | |
| if any(phrase.lower() in final_answer_text.lower() for phrase in template_phrases): | |
| print(f"🚫 Ignoring template FINAL ANSWER: '{final_answer_text}'") | |
| # Continue processing as if no final answer was found | |
| else: | |
| print(f"🎯 FINAL ANSWER found: '{final_answer_text}' - ENDING") | |
| # Store the answer in state for easy access | |
| new_state["answer"] = final_answer_text | |
| # Clean up the message content to just show the final answer | |
| clean_content = f"FINAL ANSWER: {final_answer_text}" | |
| updated_ai_message = AIMessage(content=clean_content, tool_calls=[]) | |
| new_state["messages"] = messages[:-1] + [updated_ai_message] | |
| new_state["done"] = True | |
| return new_state | |
| # If no FINAL ANSWER, look for tool calls | |
| action_match = re.search(r"Action:\s*([^\n]+)", content, re.IGNORECASE) | |
| action_input_match = re.search(r"Action Input:\s*(.+)", content, re.IGNORECASE | re.DOTALL) | |
| if action_match and action_input_match: | |
| tool_name = action_match.group(1).strip() | |
| tool_input_raw = action_input_match.group(1).strip() | |
| if tool_name.lower() == "none": | |
| print("Action is 'None' - treating as regular response") | |
| updated_ai_message = AIMessage(content=content, tool_calls=[]) | |
| new_state["messages"] = messages[:-1] + [updated_ai_message] | |
| new_state.pop("done", None) | |
| return new_state | |
| print(f"🔧 Tool call: {tool_name} with input: {tool_input_raw[:100]}...") | |
| # Parse tool arguments | |
| tool_args = {} | |
| try: | |
| trimmed_input = tool_input_raw.strip() | |
| if (trimmed_input.startswith('{') and trimmed_input.endswith('}')) or \ | |
| (trimmed_input.startswith('[') and trimmed_input.endswith(']')): | |
| tool_args = ast.literal_eval(trimmed_input) | |
| if not isinstance(tool_args, dict): | |
| tool_args = {"query": tool_input_raw} | |
| else: | |
| tool_args = {"query": tool_input_raw} | |
| except (ValueError, SyntaxError): | |
| tool_args = {"query": tool_input_raw} | |
| tool_call_id = str(uuid.uuid4()) | |
| parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}] | |
| updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls) | |
| new_state["messages"] = messages[:-1] + [updated_ai_message] | |
| new_state.pop("done", None) | |
| return new_state | |
| # No tool call or final answer - treat as regular response | |
| print("No actionable content found - continuing conversation") | |
| updated_ai_message = AIMessage(content=content, tool_calls=[]) | |
| new_state["messages"] = messages[:-1] + [updated_ai_message] | |
| new_state.pop("done", None) | |
| # DEBUG | |
| print(f"parse_react_output: Exit message count: {len(new_state['messages'])}") | |
| return new_state | |
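| # Hedged example (hypothetical): content "Action: enhanced_search\nAction Input: capital of France" | |
| # attaches tool_calls=[{"name": "enhanced_search", "args": {"query": "capital of France"}, "id": <uuid4>}] | |
| # to the AIMessage, which routes the graph to call_tool. | |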
| # Improved call_tool_with_memory_management to prevent duplicate processing | |
| def call_tool_with_memory_management(state: AgentState) -> AgentState: | |
| """Process tool calls with memory management, avoiding duplicates.""" | |
| print("Running call_tool with memory management...") | |
| # Clear CUDA cache before processing (torch is imported at module top) | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"🧹 Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
| except Exception as e: | |
| print(f"Error clearing CUDA cache: {e}") | |
| # Check if we have parsed tool calls from the condition function | |
| if 'parsed_tool_calls' in state and state.get('parsed_tool_calls'): | |
| print("Executing parsed tool calls...") | |
| return execute_parsed_tool_calls(state) | |
| # Fallback to original OpenAI-style tool calls handling | |
| messages = state.get("messages", []) | |
| if not messages: | |
| print("No messages found in state.") | |
| return state | |
| last_message = messages[-1] | |
| if not hasattr(last_message, "tool_calls") or not last_message.tool_calls: | |
| print("No tool calls found in last message") | |
| return state | |
| # Avoid processing the same tool calls multiple times | |
| if hasattr(last_message, '_processed_tool_calls'): | |
| print("Tool calls already processed, skipping...") | |
| return state | |
| # Copy the messages to avoid mutating the original list | |
| new_messages = list(messages) | |
| print(f"Processing {len(last_message.tool_calls)} tool calls from last message") | |
| # Get file_path from state to pass to tools | |
| file_path_to_pass = state.get('file_path') | |
| for i, tool_call_item in enumerate(last_message.tool_calls): | |
| # Handle both dict and object-style tool calls | |
| if isinstance(tool_call_item, dict): | |
| tool_name = tool_call_item.get("name", "") | |
| raw_args = tool_call_item.get("args") | |
| tool_call_id = tool_call_item.get("id", str(uuid.uuid4())) | |
| elif hasattr(tool_call_item, "name") and hasattr(tool_call_item, "id"): | |
| tool_name = getattr(tool_call_item, "name", "") | |
| raw_args = getattr(tool_call_item, "args", None) | |
| tool_call_id = getattr(tool_call_item, "id", str(uuid.uuid4())) | |
| else: | |
| print(f"Skipping malformed tool call item: {tool_call_item}") | |
| continue | |
| print(f"Processing tool call {i+1}: {tool_name}") | |
| # Find the matching tool | |
| selected_tool = None | |
| for tool_instance in tools: | |
| if tool_instance.name.lower() == tool_name.lower(): | |
| selected_tool = tool_instance | |
| break | |
| if not selected_tool: | |
| tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}" | |
| print(f"Tool not found: {tool_name}") | |
| else: | |
| try: | |
| # Prepare the arguments for the tool.run() method | |
| tool_run_input_dict = {} | |
| if isinstance(raw_args, dict): | |
| tool_run_input_dict = raw_args.copy() | |
| elif raw_args is not None: | |
| tool_run_input_dict["query"] = str(raw_args) | |
| # Add file_path to the dictionary for the tool | |
| tool_run_input_dict['file_path'] = file_path_to_pass | |
| print(f"Executing {tool_name} with args: {tool_run_input_dict} ...") | |
| tool_result = selected_tool.run(tool_run_input_dict) | |
| # Aggressive truncation to prevent memory issues | |
| if not isinstance(tool_result, str): | |
| tool_result = str(tool_result) | |
| max_length = 18000  # uniform cap for all tools (both branches of the old conditional were 18000) | |
| if len(tool_result) > max_length: | |
| original_length = len(tool_result) | |
| tool_result = tool_result[:max_length] + f"... [Result truncated from {original_length} to {max_length} chars to prevent memory issues]" | |
| print(f"📄 Truncated result to {max_length} characters") | |
| print(f"Tool result length: {len(tool_result)} characters") | |
| except Exception as e: | |
| tool_result = f"Error executing tool '{tool_name}': {str(e)}" | |
| print(f"Tool execution error: {e}") | |
| # Create tool message - ONLY ONE PER TOOL CALL | |
| tool_message = ToolMessage( | |
| content=tool_result, | |
| name=tool_name, | |
| tool_call_id=tool_call_id | |
| ) | |
| new_messages.append(tool_message) | |
| print(f"Added tool message for {tool_name}") | |
| # Mark the last message as processed to prevent re-processing | |
| if hasattr(last_message, '__dict__'): | |
| last_message._processed_tool_calls = True | |
| # Update the state | |
| new_state = state.copy() | |
| new_state["messages"] = new_messages | |
| # Clear CUDA cache after processing | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| print(f"🧹 Cleared CUDA cache post-processing. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB") | |
| except Exception as e: | |
| print(f"Error clearing CUDA cache post-processing: {e}") | |
| return new_state | |
| # Enhanced execute_parsed_tool_calls to prevent duplicate observations | |
| def execute_parsed_tool_calls(state: AgentState): | |
| """ | |
| Execute tool calls that were parsed from the Thought/Action/Action Input format. | |
| This is called by call_tool when parsed_tool_calls are present in state. | |
| """ | |
| # Tool name mappings | |
| tool_name_mappings = { | |
| 'wikipedia_semantic_search': 'wikipedia_tool', | |
| 'wikipedia': 'wikipedia_tool', | |
| 'search': 'enhanced_search', | |
| 'duckduckgo_search': 'enhanced_search', | |
| 'web_search': 'enhanced_search', | |
| 'enhanced_search': 'enhanced_search', | |
| 'youtube_screenshot_qa_tool': 'youtube_tool', | |
| 'youtube': 'youtube_tool', | |
| 'youtube_transcript_extractor': 'youtube_transcript_extractor', | |
| 'youtube_audio_tool': 'youtube_transcript_extractor' | |
| } | |
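| # e.g. a model emitting "Action: wikipedia" resolves through this table to the tool | |
| # registered as "wikipedia_tool" (assuming that name is present in `tools`). | |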
| # Create a lookup by tool names | |
| tools_by_name = {} | |
| for tool_obj in tools:  # avoid shadowing the imported `tool` decorator | |
| tools_by_name[tool_obj.name.lower()] = tool_obj | |
| # Copy messages to avoid mutation during iteration | |
| new_messages = list(state["messages"]) | |
| # Process each tool call ONCE | |
| for tool_call in state['parsed_tool_calls']: | |
| action = tool_call['action'] | |
| action_input = tool_call['action_input'] | |
| normalized_action = tool_call['normalized_action'] | |
| print(f"🚀 Executing tool: {action} with input: {action_input}") | |
| # Find the tool instance | |
| tool_instance = None | |
| if normalized_action in tools_by_name: | |
| tool_instance = tools_by_name[normalized_action] | |
| elif normalized_action in tool_name_mappings: | |
| mapped_name = tool_name_mappings[normalized_action] | |
| if mapped_name in tools_by_name: | |
| tool_instance = tools_by_name[mapped_name] | |
| if tool_instance: | |
| try: | |
| # Pass file_path if the tool expects it | |
| if hasattr(tool_instance, 'run'): | |
| if 'file_path' in tool_instance.run.__code__.co_varnames: | |
| result = tool_instance.run(action_input, file_path=state.get('file_path')) | |
| else: | |
| result = tool_instance.run(action_input) | |
| else: | |
| result = str(tool_instance) | |
| # Truncate long results (coerce to str first; tools may return non-string payloads) | |
| if not isinstance(result, str): | |
| result = str(result) | |
| if len(result) > 6000: | |
| result = result[:6000] + "... [Result truncated due to length]" | |
| # Create a SINGLE observation message (ToolMessage is imported at module top) | |
| tool_message = ToolMessage( | |
| content=f"Observation: {result}", | |
| name=action, | |
| tool_call_id=str(uuid.uuid4()) | |
| ) | |
| new_messages.append(tool_message) | |
| print(f"✅ Tool '{action}' executed successfully") | |
| except Exception as e: | |
| print(f"❌ Error executing tool '{action}': {e}") | |
| error_message = ToolMessage( | |
| content=f"Observation: Error executing '{action}': {str(e)}", | |
| name=action, | |
| tool_call_id=str(uuid.uuid4()) | |
| ) | |
| new_messages.append(error_message) | |
| else: | |
| print(f"❌ Tool '{action}' not found in available tools") | |
| available_tool_names = list(tools_by_name.keys()) | |
| error_message = ToolMessage( | |
| content=f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}", | |
| name=action, | |
| tool_call_id=str(uuid.uuid4()) | |
| ) | |
| new_messages.append(error_message) | |
| # Update state with new messages and clear parsed tool calls | |
| new_state = state.copy() | |
| new_state["messages"] = new_messages | |
| new_state['parsed_tool_calls'] = [] # Clear to prevent re-execution | |
| return new_state | |
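| # Each executed action above yields exactly one "Observation: ..." ToolMessage, so the | |
| # next LLM turn sees a single observation per Action (the duplicate-prevention goal). | |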
| # Loop detection for the agent's continuation logic | |
| def should_continue(state: AgentState) -> str: | |
| """Enhanced continuation logic with better limits.""" | |
| print("Running should_continue...") | |
| # Check done flag first | |
| if state.get("done", False): | |
| print("✅ Done flag is True - ending") | |
| return "end" | |
| messages = state["messages"] | |
| # Optional hard cap on message count (currently disabled): | |
| #if len(messages) > 20: | |
| # print(f"⚠️ Message limit reached ({len(messages)}/20) - forcing end") | |
| # return "end" | |
| # Check for repeated patterns (stuck in loop) | |
| if len(messages) >= 6: | |
| recent_contents = [str(msg.content)[:100] for msg in messages[-6:] if hasattr(msg, 'content')] | |
| if len(set(recent_contents)) < 3: # Too much repetition | |
| print("🔄 Detected repetitive pattern - ending") | |
| return "end" | |
| print(f"📊 Continuing... ({len(messages)} messages so far)") | |
| return "continue" | |
| def route_after_parse_react(state: AgentState) -> str: | |
| """Determines the next step after parsing LLM output, prioritizing end state.""" | |
| if state.get("done", False): # Check if parse_react_output decided we are done | |
| return "end_processing" | |
| # Original logic: check for tool calls in the last message | |
| # Ensure messages list and last message exist before checking tool_calls | |
| messages = state.get("messages", []) | |
| if messages: | |
| last_message = messages[-1] | |
| if hasattr(last_message, 'tool_calls') and last_message.tool_calls: | |
| return "call_tool" | |
| return "call_llm" | |
| # --- Graph Construction --- | |
| def create_memory_safe_workflow(): | |
| """Create a workflow with memory management and loop prevention.""" | |
| # These models are initialized here but may need to be released/reinitialized | |
| # (as attempted in run_agent); consider passing them in or managing their lifecycle explicitly. | |
| hf_pipe = create_llm_pipeline() | |
| llm = HuggingFacePipeline(pipeline=hf_pipe) | |
| # vqa_model_name = "Salesforce/blip-vqa-base" # Not used in the provided graph logic directly | |
| # processor_vqa = BlipProcessor.from_pretrained(vqa_model_name) # Not used | |
| # model_vqa = BlipForQuestionAnswering.from_pretrained(vqa_model_name).to('cpu') # Not used | |
| workflow = StateGraph(AgentState) | |
| # Bind the llm_model to the call_llm_with_memory_management function | |
| bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm) | |
| # Add nodes with memory-safe versions | |
| workflow.add_node("call_llm", bound_call_llm) # Use the bound version here | |
| workflow.add_node("parse_react_output", parse_react_output) | |
| workflow.add_node("call_tool", call_tool_with_memory_management) # Ensure this doesn't also need llm if it calls back directly | |
| # Set entry point | |
| workflow.set_entry_point("call_llm") | |
| # Add conditional edges | |
| workflow.add_conditional_edges( | |
| "call_llm", | |
| should_continue, | |
| { | |
| "continue": "parse_react_output", | |
| "end": END | |
| } | |
| ) | |
| workflow.add_conditional_edges( | |
| "parse_react_output", | |
| route_after_parse_react, | |
| { | |
| "call_tool": "call_tool", | |
| "call_llm": "call_llm", | |
| "end_processing": END | |
| } | |
| ) | |
| workflow.add_edge("call_tool", "call_llm") | |
| return workflow.compile() | |
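| # Resulting control flow: call_llm -> should_continue -> parse_react_output -> | |
| # route_after_parse_react -> (call_tool -> call_llm, back to call_llm, or END). | |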
| def count_english_words(text): | |
| # Remove punctuation, lowercase, split into words | |
| table = str.maketrans('', '', string.punctuation) | |
| words_in_text = text.translate(table).lower().split() | |
| return sum(1 for word in words_in_text if word in english_words) | |
| def fix_backwards_text(text): | |
| reversed_text = text[::-1] | |
| original_count = count_english_words(text) | |
| reversed_count = count_english_words(reversed_text) | |
| if reversed_count > original_count: | |
| return reversed_text | |
| else: | |
| return text | |
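| # Illustrative example: fix_backwards_text(".rewsna eht si tfel") returns "left is the answer." | |
| # because the reversed string scores more NLTK dictionary words than the original. | |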
| # --- Run the Agent --- | |
| # Enhanced system prompt for better behavior | |
| def run_agent(agent, state: AgentState): | |
| """Enhanced agent initialization with better prompt and hallucination prevention.""" | |
| global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, YOUTUBE_AUDIO_TOOL, AUDIO_TRANSCRIPTION_TOOL, EXCEL_TOOL, PYTHON_TOOL, COMMUTATIVITY_TOOL, tools | |
| # Initialize tools | |
| WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS() | |
| SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=18000) | |
| YOUTUBE_TOOL = EnhancedYoutubeScreenshotQA() | |
| YOUTUBE_AUDIO_TOOL = YouTubeTranscriptExtractor() | |
| AUDIO_TRANSCRIPTION_TOOL = AudioTranscriptionTool() | |
| EXCEL_TOOL = ExcelReaderTool() | |
| PYTHON_TOOL = PythonExecutorTool() | |
| COMMUTATIVITY_TOOL = CommutativityAnalysisTool() | |
| tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_AUDIO_TOOL, YOUTUBE_TOOL, AUDIO_TRANSCRIPTION_TOOL, EXCEL_TOOL, PYTHON_TOOL, COMMUTATIVITY_TOOL] | |
| formatted_tools_description = render_text_description(tools) | |
| current_date_str = datetime.now().strftime("%Y-%m-%d") | |
| # Enhanced system prompt with stricter boundaries | |
| system_content = f"""You are an AI assistant with access to these tools: | |
| {formatted_tools_description} | |
| CRITICAL INSTRUCTIONS: | |
| 1. Answer ONLY the question asked by the human | |
| 2. Do NOT generate additional questions or continue conversations | |
| 3. Use tools ONLY when you need specific information you don't know | |
| 4. After using a tool, provide your FINAL ANSWER immediately | |
| 5. STOP after giving your FINAL ANSWER - do not continue | |
| 6. Do not repeat words from the question in the answer | |
| FORMAT for tool use: | |
| Thought: <brief reasoning> | |
| Action: <exact_tool_name> | |
| Action Input: <tool_input> | |
| When you have the answer, immediately provide: | |
| FINAL ANSWER: [concise answer only] | |
| ANSWER FORMAT: | |
| - Numbers: no commas, no units unless specified | |
| - Questions on "how many" should be answered with a number ONLY | |
| - Strings: no articles, no abbreviations, digits in plain text | |
| - Lists: comma-separated either in ascending numeric order or alphabetical order as requested | |
| - Be extremely brief and concise | |
| - Do not provide additional context or explanations | |
| - Do not provide parentheticals | |
| IMPORTANT: You are responding to ONE question only. Do not ask follow-up questions or generate additional dialogue. | |
| Current date: {current_date_str} | |
| """ | |
| query = fix_backwards_text(state['question']) | |
| # Check for YouTube URLs | |
| yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+" | |
| if re.search(yt_pattern, query): | |
| url_match = re.search(r"(https?://[^\s]+)", query) | |
| if url_match: | |
| state['youtube_url'] = url_match.group(0) | |
| # Initialize messages | |
| system_message = SystemMessage(content=system_content) | |
| human_message = HumanMessage(content=query) | |
| state['messages'] = [system_message, human_message] | |
| state["done"] = False | |
| # Run the agent | |
| result = agent.invoke(state) | |
| # Cleanup | |
| if result.get("done"): | |
| #torch.cuda.empty_cache() | |
| #torch.cuda.ipc_collect() | |
| gc.collect() | |
| print("🧹 Released GPU memory after completion") | |
| return result["messages"] | |