Spaces:
Sleeping
Sleeping
| # Standard Library | |
| import os | |
| import re | |
| import tempfile | |
| import string | |
| import glob | |
| import shutil | |
| import gc | |
| import uuid | |
| import signal | |
| from datetime import datetime | |
| from io import BytesIO | |
| from contextlib import contextmanager | |
| from langchain_huggingface import HuggingFacePipeline | |
| from typing import TypedDict, List, Optional, Dict, Any, Annotated, Literal, Union, Tuple, Set | |
| import time | |
| from collections import Counter | |
| # Third-Party Packages | |
| import cv2 | |
| import requests | |
| import wikipedia | |
| import spacy | |
| import yt_dlp | |
| import librosa | |
| from PIL import Image | |
| from bs4 import BeautifulSoup | |
| from duckduckgo_search import DDGS | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import BlipProcessor, BlipForQuestionAnswering, pipeline, AutoTokenizer | |
| # LangChain Ecosystem | |
| from langchain.docstore.document import Document | |
| from langchain.prompts import PromptTemplate | |
| from langchain_community.document_loaders import WikipediaLoader | |
| from langchain_huggingface import HuggingFaceEndpoint | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.vectorstores import FAISS | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain.schema import Document | |
| from langchain_community.tools import DuckDuckGoSearchRun | |
| from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, BaseMessage, SystemMessage, ToolMessage | |
| from langchain_core.tools import BaseTool, StructuredTool, tool, render_text_description | |
| from langchain_core.documents import Document | |
| # LangGraph | |
| from langgraph.graph import START, END, StateGraph | |
| from langgraph.prebuilt import ToolNode, tools_condition | |
| # PyTorch | |
| import torch | |
| from functools import partial | |
| from transformers import pipeline | |
| # Additional Utilities | |
| from datetime import datetime | |
| from urllib.parse import urljoin, urlparse | |
| import logging | |
# Shared spaCy English pipeline, loaded once when this module is imported.
# spacy.load raises if the "en_core_web_sm" model package is not installed.
nlp = spacy.load("en_core_web_sm")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
| # --- Model Configuration --- | |
def create_llm_pipeline(
    model_id: str = "mistralai/Mistral-7B-Instruct-v0.3",
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
):
    """Build the Hugging Face ``text-generation`` pipeline used as the agent's LLM.

    Args:
        model_id: Hub identifier of the causal language model to load.
            Checkpoints previously tried in this file include
            "meta-llama/Llama-2-13b-chat-hf",
            "meta-llama/Llama-3.3-70B-Instruct",
            "mistralai/Mistral-Small-24B-Base-2501" and
            "Qwen/Qwen2-7B-Instruct".
        max_new_tokens: Upper bound on the number of tokens generated per call.
        temperature: Sampling temperature forwarded to the pipeline.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        "text-generation" task.
    """
    # Load the tokenizer explicitly so the fast implementation is used.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=True,
        # NOTE(review): original comment says "only if actually needed" --
        # confirm this flag is appropriate for the selected checkpoint.
        add_prefix_space=True,
    )
    # NOTE(review): no do_sample flag is passed here; confirm whether the
    # temperature setting is expected to take effect with default decoding.
    return pipeline(
        "text-generation",
        model=model_id,
        tokenizer=tokenizer,
        device_map="auto",
        torch_dtype=torch.float16,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
    )
# Known file extensions (lower-case, with leading dot), one set per category.
PICTURE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'}
AUDIO_EXTENSIONS = {'.mp3', '.wav', '.aac', '.flac', '.ogg', '.m4a', '.wma'}
CODE_EXTENSIONS = {'.py', '.js', '.java', '.cpp', '.c', '.cs', '.rb', '.go', '.php', '.html', '.css', '.ts'}
SPREADSHEET_EXTENSIONS = {
    '.xls', '.xlsx', '.xlsm', '.xlsb', '.xlt', '.xltx', '.xltm',
    '.ods', '.ots', '.csv', '.tsv', '.sxc', '.stc', '.dif', '.gsheet',
    '.numbers', '.numbers-tef', '.nmbtemplate', '.fods', '.123', '.wk1', '.wk2',
    '.wks', '.wku', '.wr1', '.gnumeric', '.gnm', '.xml', '.pmvx', '.pmdx',
    '.pmv', '.uos', '.txt'
}


def get_file_type(filename: str) -> str:
    """Classify a file name by the text after its last dot.

    Returns 'picture', 'audio', 'code' or 'spreadsheet' when the
    lower-cased extension is in the matching set, 'unknown' when it is in
    none of them, and '' when the name is empty or contains no dot.
    """
    # Guard clause: nothing to classify without a dot-separated suffix.
    if not filename or '.' not in filename:
        return ''

    _, _, suffix = filename.lower().rpartition('.')
    extension = '.' + suffix

    # Checked in this order, mirroring the category priority.
    categories = (
        ('picture', PICTURE_EXTENSIONS),
        ('audio', AUDIO_EXTENSIONS),
        ('code', CODE_EXTENSIONS),
        ('spreadsheet', SPREADSHEET_EXTENSIONS),
    )
    for label, known_extensions in categories:
        if extension in known_extensions:
            return label
    return 'unknown'
def write_bytes_to_temp_dir(file_bytes: bytes, file_name: str) -> str:
    """Write ``file_bytes`` to a file named ``file_name`` inside ``/tmp``.

    Only the final path component of ``file_name`` is used, so names that
    contain directory parts (``../x``, ``/etc/x``) cannot place the file
    outside ``/tmp``.

    Args:
        file_bytes: Raw content to write.
        file_name: Desired file name; any directory components are dropped.

    Returns:
        The full path of the written file. The file is not removed by this
        function.

    Raises:
        OSError: If the file cannot be opened or written (for example when
            ``file_name`` has no usable final component).
    """
    temp_dir = "/tmp"  # /tmp is always writable in Hugging Face Spaces
    os.makedirs(temp_dir, exist_ok=True)
    # Strip directory components: os.path.join would otherwise honour
    # absolute paths and ".." segments supplied by the caller.
    safe_name = os.path.basename(file_name)
    file_path = os.path.join(temp_dir, safe_name)
    with open(file_path, 'wb') as f:
        f.write(file_bytes)
    print(f"File written to: {file_path}")
    return file_path
def extract_final_answer(text: str) -> str:
    """Return the tail of ``text`` starting at the last 'FINAL ANSWER:' marker.

    The marker is matched case-insensitively and is itself included in the
    returned string. Surrounding whitespace is stripped and any trailing
    punctuation or spaces are removed.

    Args:
        text: Model output to scan. ``None`` or an empty string yields "".

    Returns:
        The substring from the last marker to the end of ``text`` (cleaned
        as described), or "" when the marker is absent.
    """
    if not text:
        return ""
    marker = "FINAL ANSWER:"
    # Search the original string rather than text.lower(): lower-casing can
    # change the length of some Unicode strings, which would make an index
    # found in the lowered copy point at the wrong place in ``text``.
    # re.ASCII limits case folding to ASCII letters.
    last_match = None
    for last_match in re.finditer(re.escape(marker), text, flags=re.IGNORECASE | re.ASCII):
        pass
    if last_match is None:
        return ""
    result = text[last_match.start():].strip()
    # Remove trailing punctuation
    return result.rstrip(string.punctuation + " ")
class EnhancedDuckDuckGoSearchTool(BaseTool):
    """DuckDuckGo search tool that also downloads and cleans the result pages.

    The search itself goes through ``DDGS().text``; each hit's URL is then
    fetched with a shared ``requests.Session`` and reduced to plain text
    with BeautifulSoup. The tool logic lives in ``_run`` so that the
    inherited ``BaseTool.run`` / ``invoke`` entry points keep their full
    signatures.
    """

    name: str = "enhanced_search"
    description: str = (
        "Performs a DuckDuckGo web search and retrieves actual content from the top web results. "
        "Input should be a search query string. "
        "Returns search results with extracted content from web pages, making it much more useful for answering questions. "
        "Use this tool when you need up-to-date information, details about current events, or when other tools do not provide sufficient or recent answers. "
        "Ideal for topics that require the latest news, recent developments, or information not covered in static sources."
    )
    # Maximum number of search hits requested and processed.
    max_results: int = 3
    # Extracted page text is cut to this many characters.
    max_chars_per_page: int = 3000
    # HTTP session; created in model_post_init, so it defaults to None here.
    session: Any = None

    # Use model_post_init for initialization logic in Pydantic v2+
    def model_post_init(self, __context: Any) -> None:
        """Create the HTTP session and install browser-like request headers."""
        super().model_post_init(__context)
        self.session = requests.Session()
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'en-US,en;q=0.5',
            'Accept-Encoding': 'gzip, deflate',
            'Connection': 'keep-alive',
            'Upgrade-Insecure-Requests': '1',
        })

    def _search_duckduckgo(self, query: str) -> List[Dict]:
        """Run the DuckDuckGo text search; return [] (and log) on any failure."""
        try:
            with DDGS() as ddgs:
                return list(ddgs.text(query, max_results=self.max_results))
        except Exception as e:
            logger.error(f"DuckDuckGo search failed: {e}")
            return []

    def _extract_content_from_url(self, url: str, timeout: int = 10) -> Optional[str]:
        """Download ``url`` and return its cleaned text, or None on failure.

        None is returned (with a log entry) when the URL points at an
        unsupported document type, the response is not HTML, the request
        fails or times out, or parsing raises. Returning None instead of an
        error sentence keeps failure messages from being mistaken for page
        content by the caller.
        """
        unsupported_suffixes = ('.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx')
        try:
            # Skip certain file types
            if url.lower().endswith(unsupported_suffixes):
                logger.info("Content type not supported for extraction: %s", url)
                return None

            response = self.session.get(url, timeout=timeout, allow_redirects=True)
            response.raise_for_status()

            # Check content type
            content_type = response.headers.get('content-type', '').lower()
            if 'text/html' not in content_type:
                logger.info("Non-HTML content detected: %s", url)
                return None

            soup = BeautifulSoup(response.content, 'html.parser')
            # Drop scripts, styles and page chrome before extracting text.
            for element in soup(["script", "style", "nav", "header", "footer", "aside", "form"]):
                element.decompose()

            # Prefer a dedicated main-content container when one exists.
            main_content = None
            for selector in ['main', 'article', '.content', '#content', '.post', '.entry']:
                main_content = soup.select_one(selector)
                if main_content:
                    break
            if not main_content:
                main_content = soup.find('body') or soup

            text = main_content.get_text(separator='\n', strip=True)
            # Drop blank lines and surrounding whitespace on every line.
            lines = [line.strip() for line in text.split('\n') if line.strip()]
            text = '\n'.join(lines)
            # Collapse runs of spaces inside lines.
            text = re.sub(r' {2,}', ' ', text)

            # Truncate if too long
            if len(text) > self.max_chars_per_page:
                text = text[:self.max_chars_per_page] + "\n[Content truncated...]"
            return text
        except requests.exceptions.Timeout:
            logger.warning("Page loading timed out: %s", url)
            return None
        except requests.exceptions.RequestException as e:
            logger.warning("Failed to retrieve page %s: %s", url, e)
            return None
        except Exception as e:
            logger.error(f"Content extraction failed for {url}: {e}")
            return None

    def _format_search_result(self, result: Dict, content: str) -> str:
        """Format one search hit (title, URL, snippet) together with its page text."""
        title = result.get('title', 'No title')
        url = result.get('href', 'No URL')
        snippet = result.get('body', 'No snippet')
        return (
            "\n"
            f"🔍 **{title}**\n"
            f"URL: {url}\n"
            f"Snippet: {snippet}\n"
            "📄 **Page Content:**\n"
            f"{content}\n"
            "---\n"
        )

    def _run(self, query: str) -> str:
        """Search DuckDuckGo for ``query`` and return hits with extracted page text.

        Args:
            query: Search query string.

        Returns:
            A formatted report of the hits whose pages yielded more than 50
            characters of text, truncated to about 8000 characters, or a
            short message when the query is blank, the search returned
            nothing, or no page content could be extracted.
        """
        if not query or not query.strip():
            return "Please provide a search query."
        query = query.strip()
        logger.info(f"Searching for: {query}")

        # Perform DuckDuckGo search
        search_results = self._search_duckduckgo(query)
        if not search_results:
            return f"No search results found for query: {query}"

        # Process each result and extract content
        enhanced_results = []
        processed_count = 0
        for i, result in enumerate(search_results[:self.max_results]):
            url = result.get('href', '')
            if not url:
                continue
            logger.info(f"Processing result {i+1}: {url}")
            content = self._extract_content_from_url(url)
            # Only include results with substantial content
            if content and len(content.strip()) > 50:
                enhanced_results.append(self._format_search_result(result, content))
                processed_count += 1
            # Small delay to be respectful to servers
            time.sleep(0.5)

        if not enhanced_results:
            return f"Search completed but no content could be extracted from the pages for query: {query}"

        # Compile final response
        response = (
            f"🔍 **Enhanced Search Results for: \"{query}\"**\n"
            f"Found {len(search_results)} results, successfully processed {processed_count} pages with content.\n"
            f"{''.join(enhanced_results)}\n"
            f"💡 **Summary:** Retrieved and processed content from {processed_count} web pages to provide comprehensive information about your search query.\n"
        )
        # Ensure the response isn't too long
        if len(response) > 8000:
            response = response[:8000] + "\n[Response truncated to prevent memory issues]"
        return response
| # --- Agent State Definition --- | |
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph.

    TypedDict fields cannot carry default values, so every key (including
    ``done``) must be supplied explicitly when the initial state is built.
    """

    # Conversation history; the second Annotated argument concatenates an
    # existing list with a newly returned one.
    messages: Annotated[List[AnyMessage], lambda x, y: x + y]
    # Completion flag; callers should initialise it to False themselves.
    done: bool
    question: str
    task_id: str
    # Raw bytes of an attached file, if any.
    input_file: Optional[bytes]
    file_type: Optional[str]
    context: List[Document]  # Using LangChain's Document class
    file_path: Optional[str]
    youtube_url: Optional[str]
    answer: Optional[str]
    frame_answers: Optional[list]
def fetch_page_with_tables(page_title):
    """Load a Wikipedia page and render each of its HTML tables as text.

    Every table row becomes one line whose cell texts are joined with
    " | "; rows without cells and tables without rows are dropped.

    Returns:
        A tuple ``(main_text, table_texts)`` with the page's plain-text
        content and a list holding one string per non-empty table.
    """
    wiki_page = wikipedia.page(page_title)
    article_text = wiki_page.content

    # The rendered HTML is needed because the plain-text content is read
    # separately from the table markup.
    parsed = BeautifulSoup(wiki_page.html(), 'html.parser')

    rendered_tables = []
    for table_tag in parsed.find_all('table'):
        row_strings = []
        for row_tag in table_tag.find_all('tr'):
            values = [cell.get_text(strip=True) for cell in row_tag.find_all(['th', 'td'])]
            if values:
                row_strings.append(" | ".join(values))
        if row_strings:
            rendered_tables.append("\n".join(row_strings))
    return article_text, rendered_tables
class WikipediaSearchToolWithFAISS(BaseTool):
    """Wikipedia lookup tool with entity-prioritised page search and FAISS retrieval.

    For each query the tool (1) extracts entities and keywords with spaCy,
    (2) derives candidate page titles and loads the matching Wikipedia
    pages (main text plus tables), (3) splits the text into chunks and
    indexes them in a fresh FAISS store, and (4) runs several query
    variants against the index and returns the best-scoring chunks.
    """

    name: str = "wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval"
    description: str = (
        "Fetches content from multiple Wikipedia pages based on intelligent NLP query processing "
        "of various search candidates, with strong prioritization of query entities. It then performs "
        "entity-focused semantic search across all fetched content to find the most relevant information, "
        "with improved retrieval for lists like discographies. Uses spaCy for named entity "
        "recognition and query enhancement. Input should be a search query or topic. "
        "Note: Uses the current live version of Wikipedia."
    )
    embedding_model_name: str = "all-MiniLM-L6-v2"
    chunk_size: int = 4000
    chunk_overlap: int = 250  # Maintained moderate overlap
    top_k_results: int = 3
    spacy_model: str = "en_core_web_sm"
    # Each semantic query variant fetches top_k_results * this many chunks.
    semantic_search_candidate_multiplier: int = 1

    def __init__(self, **kwargs):
        """Load spaCy, the embedding model and the text splitter.

        On any failure all three components are set to None and the error
        is printed; ``_run`` then reports that the tool is unavailable.
        """
        super().__init__(**kwargs)
        try:
            self._nlp = spacy.load(self.spacy_model)
            print(f"Loaded spaCy model: {self.spacy_model}")
            self._embedding_model = HuggingFaceEmbeddings(model_name=self.embedding_model_name)
            # Refined separators for better handling of Wikipedia lists and sections
            self._text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
                separators=[
                    "\n\n== ", "\n\n=== ", "\n\n==== ",  # Section headers (keep with following content)
                    "\n\n\n", "\n\n",  # Multiple newlines (paragraph breaks)
                    "\n* ", "\n- ", "\n# ",  # List items
                    "\n", ". ", "! ", "? ",  # Sentence breaks after newline, common punctuation
                    " ", ""  # Word and character level
                ]
            )
        except OSError as e:
            print(f"Error loading spaCy model '{self.spacy_model}': {e}")
            print("Try running: python -m spacy download en_core_web_sm")
            self._nlp = None
            self._embedding_model = None
            self._text_splitter = None
        except Exception as e:
            print(f"Error initializing WikipediaSearchToolWithFAISS components: {e}")
            self._nlp = None
            self._embedding_model = None
            self._text_splitter = None

    def _extract_entities_and_keywords(self, query: str) -> Tuple[List[str], List[str], str]:
        """Return ``(main_entities, keywords, processed_query)`` for ``query``.

        Entities are spaCy entities labelled PERSON, ORG, GPE, EVENT or
        WORK_OF_ART; keywords are lower-cased lemmas of non-stop nouns,
        proper nouns and adjectives longer than two characters; the
        processed query is the lemmas of all non-stop, non-punctuation
        tokens joined by spaces. Without a loaded spaCy model the query is
        returned unchanged with empty lists.
        """
        if not self._nlp:
            return [], [], query
        doc = self._nlp(query)
        main_entities = [ent.text for ent in doc.ents if ent.label_ in ["PERSON", "ORG", "GPE", "EVENT", "WORK_OF_ART"]]
        keywords = [
            token.lemma_.lower() for token in doc
            if token.pos_ in ["NOUN", "PROPN", "ADJ"] and not token.is_stop and not token.is_punct and len(token.text) > 2
        ]
        # Order-preserving de-duplication.
        main_entities = list(dict.fromkeys(main_entities))
        keywords = list(dict.fromkeys(keywords))
        processed_tokens = [token.lemma_ for token in doc if not token.is_stop and not token.is_punct and token.text.strip()]
        processed_query = " ".join(processed_tokens)
        return main_entities, keywords, processed_query

    def _generate_search_candidates(self, query: str, main_entities: List[str], keywords: List[str], processed_query: str) -> List[str]:
        """Build the ordered list of page-title candidates for a query.

        Main entities come first, followed by the remaining candidates in
        the order they were generated. An insertion-ordered dict is used
        instead of a set so the order is the same on every run; the caller
        treats earlier candidates as higher priority.
        """
        pool: Dict[str, None] = {}
        entity_prefix = main_entities[0] if main_entities else None

        for me in main_entities:
            pool[me] = None
        pool[query] = None
        if processed_query and processed_query != query:
            pool[processed_query] = None

        if entity_prefix and keywords:
            first_entity_lower = entity_prefix.lower()
            for kw in keywords[:3]:
                if kw not in first_entity_lower and len(kw) > 2:
                    pool[f"{entity_prefix} {kw}"] = None
            keyword_combo_short = " ".join(k for k in keywords[:2] if k not in first_entity_lower and len(k) > 2)
            if keyword_combo_short:
                pool[f"{entity_prefix} {keyword_combo_short}"] = None

        if len(main_entities) > 1:
            pool[" ".join(main_entities[:2])] = None

        if keywords:
            keyword_combo = " ".join(keywords[:2])
            if entity_prefix:
                candidate_to_add = f"{entity_prefix} {keyword_combo}"
                # Skip if an equal candidate (ignoring case) already exists.
                if not any(c.lower() == candidate_to_add.lower() for c in pool):
                    pool[candidate_to_add] = None
            elif not main_entities:
                pool[keyword_combo] = None

        ordered_candidates: List[str] = []
        for me in main_entities:
            if me not in ordered_candidates:
                ordered_candidates.append(me)
        for c in pool:
            if c and c.strip() and c not in ordered_candidates:
                ordered_candidates.append(c)
        print(f"Generated {len(ordered_candidates)} search candidates for Wikipedia page lookup (entity-prioritized): {ordered_candidates}")
        return ordered_candidates

    def _title_matches_entities(self, title: str, entities: List[str]) -> bool:
        """Return True when any entity occurs (case-insensitively) in ``title``."""
        lowered_title = title.lower()
        return any(me.lower() in lowered_title for me in entities)

    def _is_acceptable_page(self, title: str, origin: str, entity_focused: bool,
                            entities: List[str], seen_titles: Set[str]) -> bool:
        """Decide whether a resolved page should be kept.

        A page is rejected (with a printed reason) when the candidate was
        entity-focused but the title mentions none of the query entities,
        or when the title was already processed. ``origin`` describes how
        the page was reached and only appears in the printed message.
        """
        if entity_focused and entities and not self._title_matches_entities(title, entities):
            print(f" ! Page title '{title}' ({origin}) "
                  f"does not strongly match main query entities: {entities}. Skipping.")
            return False
        if title in seen_titles:
            print(f" ~ Already processed '{title}'")
            return False
        return True

    def _extract_table_texts(self, page_object) -> List[str]:
        """Render every HTML table of ``page_object`` as " | "-joined rows."""
        soup = BeautifulSoup(page_object.html(), 'html.parser')
        table_texts = []
        for table in soup.find_all('table'):
            table_lines = []
            for row in table.find_all('tr'):
                cell_texts = [cell.get_text(strip=True) for cell in row.find_all(['th', 'td'])]
                if cell_texts:
                    table_lines.append(" | ".join(cell_texts))
            if table_lines:
                table_texts.append("\n".join(table_lines))
        return table_texts

    def _smart_wikipedia_search(self, query_text: str, main_entities_from_query: List[str], keywords_from_query: List[str], processed_query_text: str) -> List[Tuple[str, str]]:
        """Resolve search candidates to Wikipedia pages and collect their text.

        Each candidate is tried as a page title; a PageError falls back to
        ``wikipedia.search`` (only for the first few candidates) and a
        DisambiguationError falls back to its first option. Accepted pages
        contribute one ``(text, page_title)`` tuple for the main text plus
        one per extracted table.

        Returns:
            List of ``(text, page_title)`` tuples; several tuples may share
            the same title.
        """
        candidates = self._generate_search_candidates(query_text, main_entities_from_query, keywords_from_query, processed_query_text)
        found_pages_data: List[Tuple[str, str]] = []
        processed_page_titles: Set[str] = set()

        for i, candidate_query in enumerate(candidates):
            print(f"\nProcessing candidate {i+1}/{len(candidates)} for page: '{candidate_query}'")
            page_object = None
            final_page_title = None
            is_candidate_entity_focused = bool(main_entities_from_query) and any(
                me.lower() in candidate_query.lower() for me in main_entities_from_query
            )
            try:
                try:
                    page_to_load = candidate_query
                    suggest_mode = True  # Default to auto_suggest=True
                    if is_candidate_entity_focused:
                        try:  # Attempt precise match first for entity-focused candidates
                            temp_page = wikipedia.page(page_to_load, auto_suggest=False, redirect=True)
                            suggest_mode = False  # Flag that precise match worked
                        except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError):
                            print(f" - auto_suggest=False failed for entity-focused '{page_to_load}', trying with auto_suggest=True.")
                    if suggest_mode:  # If not attempted or failed with auto_suggest=False
                        temp_page = wikipedia.page(page_to_load, auto_suggest=True, redirect=True)
                    final_page_title = temp_page.title
                    if not self._is_acceptable_page(
                            final_page_title,
                            f"from entity-focused candidate '{candidate_query}'",
                            is_candidate_entity_focused, main_entities_from_query, processed_page_titles):
                        continue
                    page_object = temp_page
                    print(f" ✓ Direct hit/suggestion for '{candidate_query}' -> '{final_page_title}'")
                except wikipedia.exceptions.PageError:
                    # Try Wikipedia search for a smaller, more promising subset of candidates
                    if i < max(2, len(candidates) // 3):
                        print(f" - Direct access failed for '{candidate_query}'. Trying Wikipedia search...")
                        search_results = wikipedia.search(candidate_query, results=1)
                        if not search_results:
                            print(f" - No Wikipedia search results for '{candidate_query}'.")
                            continue
                        search_result_title = search_results[0]
                        try:
                            # Search results are usually canonical
                            temp_page = wikipedia.page(search_result_title, auto_suggest=False, redirect=True)
                            final_page_title = temp_page.title
                            # Still check against original intent
                            if not self._is_acceptable_page(
                                    final_page_title,
                                    f"from search for '{candidate_query}' -> '{search_result_title}'",
                                    is_candidate_entity_focused, main_entities_from_query, processed_page_titles):
                                continue
                            page_object = temp_page
                            print(f" ✓ Found via search '{candidate_query}' -> '{search_result_title}' -> '{final_page_title}'")
                        except (wikipedia.exceptions.PageError, wikipedia.exceptions.DisambiguationError) as e_sr:
                            print(f" ! Error/Disambiguation for search result '{search_result_title}': {e_sr}")
                    else:
                        print(f" - Direct access failed for '{candidate_query}'. Skipping further search for this lower priority candidate.")
                except wikipedia.exceptions.DisambiguationError as de:
                    print(f" ! Disambiguation for '{candidate_query}'. Options: {de.options[:1]}")
                    if de.options:
                        option_title = de.options[0]
                        try:
                            temp_page = wikipedia.page(option_title, auto_suggest=False, redirect=True)
                            final_page_title = temp_page.title
                            # Check against original intent
                            if not self._is_acceptable_page(
                                    final_page_title,
                                    f"from disamb. of '{candidate_query}' -> '{option_title}'",
                                    is_candidate_entity_focused, main_entities_from_query, processed_page_titles):
                                continue
                            page_object = temp_page
                            print(f" ✓ Resolved disambiguation '{candidate_query}' -> '{option_title}' -> '{final_page_title}'")
                        except Exception as e_dis_opt:
                            print(f" ! Could not load disambiguation option '{option_title}': {e_dis_opt}")

                if page_object and final_page_title and (final_page_title not in processed_page_titles):
                    # Extract main text
                    main_text = page_object.content
                    # Extract tables using BeautifulSoup; a failure here
                    # must not discard the page's main text.
                    try:
                        table_texts = self._extract_table_texts(page_object)
                    except Exception as e:
                        print(f" !! Error extracting tables for '{final_page_title}': {e}")
                        table_texts = []
                    # Main text and each table are stored as separate entries.
                    for chunk in [main_text] + table_texts:
                        found_pages_data.append((chunk, final_page_title))
                    processed_page_titles.add(final_page_title)
                    print(f" -> Added page '{final_page_title}'. Main text length: {len(main_text)} | Tables extracted: {len(table_texts)}")
            except Exception as e:
                print(f" !! Unexpected error processing candidate '{candidate_query}': {e}")

        if not found_pages_data:
            print(f"\nCould not find any new, unique, entity-validated Wikipedia pages for query '{query_text}'.")
        else:
            # Count pages, not (text, title) entries: a page with tables
            # contributes several entries.
            print(f"\nFound {len(processed_page_titles)} unique, validated page(s) for processing.")
        return found_pages_data

    def _enhance_semantic_search(self, query: str, vector_store, main_entities: List[str], keywords: List[str], processed_query: str) -> List[Document]:
        """Query the vector store with several variants and return the top chunks.

        Variants combine the first main entity with the query, the
        processed query, the leading keywords and section phrases
        (discography / biography / filmography) triggered by the query.
        A chunk retrieved by several variants keeps its best (lowest)
        score, and the ``top_k_results`` lowest-scoring chunks are returned.
        """
        # Insertion-ordered de-duplication keeps the variant order stable.
        core_query_parts: Dict[str, None] = {query: None}
        if processed_query != query:
            core_query_parts[processed_query] = None
        if keywords:
            core_query_parts[" ".join(keywords[:2])] = None

        section_phrases_templates: List[str] = []
        lowered_query = query.lower()
        lower_query_terms = set(lowered_query.split()) | set(k.lower() for k in keywords)
        section_keywords_map = {
            "discography": ["discography", "list of studio albums", "studio album titles and years", "albums by year", "album release dates", "official albums", "complete album list", "albums published"],
            "biography": ["biography", "life story", "career details", "background history"],
            "filmography": ["filmography", "list of films", "movie appearances", "acting roles"],
        }
        for section_term_key, specific_phrases_list in section_keywords_map.items():
            # Triggered when any word of the section key is a query term ...
            if any(key_part in lower_query_terms for key_part in section_term_key.split()):
                section_phrases_templates.extend(specific_phrases_list)
            # ... or when one of its phrases appears verbatim in the query.
            if any(phrase in lowered_query for phrase in specific_phrases_list):
                section_phrases_templates.extend(specific_phrases_list)
        section_phrases_templates = list(dict.fromkeys(section_phrases_templates))  # Deduplicate

        final_search_queries: Dict[str, None] = {}
        if main_entities:
            entity_prefix = main_entities[0]
            final_search_queries[entity_prefix] = None
            for part in core_query_parts:
                variant = f"{entity_prefix} {part}" if entity_prefix.lower() not in part.lower() else part
                final_search_queries[variant] = None
            for phrase_template in section_phrases_templates:
                final_search_queries[f"{entity_prefix} {phrase_template}"] = None
                if "list of" in phrase_template or "history of" in phrase_template:
                    final_search_queries[f"{phrase_template} of {entity_prefix}"] = None
        else:
            for part in core_query_parts:
                final_search_queries[part] = None
            for phrase_template in section_phrases_templates:
                final_search_queries[phrase_template] = None

        deduplicated_queries = [sq for sq in final_search_queries if sq and sq.strip()]
        print(f"Generated {len(deduplicated_queries)} semantic search query variants (list-retrieval focused): {deduplicated_queries}")

        # Chunks are identified by their first 250 characters.
        best_by_prefix: Dict[str, Document] = {}
        k_to_fetch = self.top_k_results * self.semantic_search_candidate_multiplier
        for search_query_variant in deduplicated_queries:
            try:
                results = vector_store.similarity_search_with_score(search_query_variant, k=k_to_fetch)
                print(f" Semantic search variant '{search_query_variant}' (k={k_to_fetch}) -> {len(results)} raw chunk(s) with scores.")
                for doc, score in results:  # Assuming similarity_search_with_score returns (doc, score)
                    score = float(score)
                    content_key = doc.page_content[:250]
                    known = best_by_prefix.get(content_key)
                    # Keep the best score seen for a chunk so its rank does
                    # not depend on which variant happened to find it first.
                    if known is None or score < known.metadata.get('retrieval_score', float('inf')):
                        doc.metadata['retrieved_by_variant'] = search_query_variant
                        doc.metadata['retrieval_score'] = score
                        best_by_prefix[content_key] = doc
            except Exception as e:
                print(f" Error in semantic search for variant '{search_query_variant}': {e}")

        # Sort all collected unique results by score (FAISS L2 distance is lower is better)
        all_results_docs = sorted(best_by_prefix.values(), key=lambda x: x.metadata.get('retrieval_score', float('inf')))
        print(f"Collected and re-sorted {len(all_results_docs)} unique chunks from all semantic query variants.")
        return all_results_docs[:self.top_k_results]

    def _run(self, query: str) -> str:
        """Answer ``query`` from Wikipedia and return a formatted text report.

        Returns a human-readable string in every case: the relevant chunks
        on success, or a message describing which stage failed. Exceptions
        are caught and reported in the returned string.
        """
        if not self._nlp or not self._embedding_model or not self._text_splitter:
            print("ERROR: WikipediaSearchToolWithFAISS components not initialized properly.")
            return "Error: Wikipedia tool components not initialized properly. Please check server logs."
        try:
            print(f"\n--- Running {self.name} for query: '{query}' ---")
            main_entities, keywords, processed_query = self._extract_entities_and_keywords(query)
            print(f"Initial NLP Analysis - Main Entities: {main_entities}, Keywords: {keywords}, Processed Query: '{processed_query}'")

            fetched_pages_data = self._smart_wikipedia_search(query, main_entities, keywords, processed_query)
            if not fetched_pages_data:
                return (f"Could not find any relevant, entity-validated Wikipedia pages for the query '{query}'. "
                        f"Main entities sought: {main_entities}")

            # fetched_pages_data holds one entry per text block (main text
            # or table), so titles repeat; report each page once.
            all_page_titles = [title for _, title in fetched_pages_data]
            unique_page_titles = list(dict.fromkeys(all_page_titles))
            print(f"\nSuccessfully fetched content for {len(unique_page_titles)} Wikipedia page(s): {', '.join(unique_page_titles)}")

            all_documents: List[Document] = []
            for page_content, page_title in fetched_pages_data:
                chunks = self._text_splitter.split_text(page_content)
                if not chunks:
                    print(f"Warning: Could not split content from Wikipedia page '{page_title}' into chunks.")
                    continue
                for i, chunk_text in enumerate(chunks):
                    all_documents.append(Document(page_content=chunk_text, metadata={
                        "source_page_title": page_title,
                        "original_query": query,
                        "chunk_index": i  # Add chunk index for potential debugging or ordering
                    }))
                print(f"Split content from '{page_title}' into {len(chunks)} chunks.")

            if not all_documents:
                return (f"Could not process content into searchable chunks from the fetched Wikipedia pages "
                        f"({', '.join(unique_page_titles)}) for query '{query}'.")
            print(f"\nTotal document chunks from all pages: {len(all_documents)}")

            print("Creating FAISS index from content of all fetched pages...")
            try:
                vector_store = FAISS.from_documents(all_documents, self._embedding_model)
                print("FAISS index created successfully.")
            except Exception as e:
                return f"Error creating FAISS vector store: {e}"

            print(f"\nPerforming enhanced semantic search across all collected content...")
            try:
                relevant_docs = self._enhance_semantic_search(query, vector_store, main_entities, keywords, processed_query)
            except Exception as e:
                return f"Error during semantic search: {e}"

            if not relevant_docs:
                return (f"No relevant information found within Wikipedia page(s) '{', '.join(unique_page_titles)}' "
                        f"for your query '{query}' using entity-focused semantic search with list retrieval.")

            unique_sources_in_results = list(dict.fromkeys([doc.metadata.get('source_page_title', 'Unknown Source') for doc in relevant_docs]))
            result_header = (f"Found {len(relevant_docs)} relevant piece(s) of information from Wikipedia page(s) "
                             f"'{', '.join(unique_sources_in_results)}' for your query '{query}':\n")
            nlp_summary = (f"[Original Query NLP: Main Entities: {', '.join(main_entities) if main_entities else 'None'}, "
                           f"Keywords: {', '.join(keywords[:5]) if keywords else 'None'}]\n\n")

            result_details = []
            for i, doc in enumerate(relevant_docs):
                source_info = doc.metadata.get('source_page_title', 'Unknown Source')
                variant_info = doc.metadata.get('retrieved_by_variant', 'N/A')
                score_info = doc.metadata.get('retrieval_score')
                # Only apply the float format to numbers; a missing score
                # is shown as N/A instead of raising a format error.
                score_text = f"{score_info:.4f}" if isinstance(score_info, (int, float)) else "N/A"
                detail = (f"Result {i+1} (source: '{source_info}', score: {score_text})\n"
                          f"(Retrieved by: '{variant_info}')\n{doc.page_content}")
                result_details.append(detail)

            final_result = result_header + nlp_summary + "\n\n---\n\n".join(result_details)
            print(f"\nReturning {len(relevant_docs)} relevant chunks from {len(set(all_page_titles))} source page(s).")
            return final_result.strip()
        except Exception as e:
            import traceback
            print(f"Unexpected error in {self.name}: {traceback.format_exc()}")
            return f"An unexpected error occurred: {str(e)}"
| # Example of creating the tool instance: | |
| # wikipedia_tool_faiss = WikipediaSearchToolWithFAISS() | |
| # To use this new tool in your agent, you would replace the old | |
| # `wikipedia_tool` instance with `wikipedia_tool_faiss` in your `tools` list. | |
| # For example: | |
| # tools = [wikipedia_tool_faiss, search_tool] | |
| # Create tool instances | |
| #wikipedia_tool = WikipediaSearchTool() | |
| # --- Define Call LLM function --- | |
| # 3. Improved LLM call with memory management | |
def _trim_to_char_budget(system_message, regular_messages, char_limit):
    """Drop the oldest regular messages until the whole context fits ``char_limit``.

    The system message (when present) counts against the budget but is never
    dropped. The most recent regular message is always kept, so the LLM is
    never invoked with an empty conversation.

    Args:
        system_message: The preserved system message, or None.
        regular_messages: Non-system messages, oldest first.
        char_limit: Maximum total number of characters allowed.

    Returns:
        list: The regular messages that survive the trimming (a new list).
    """
    system_chars = len(str(system_message.content)) if system_message is not None else 0
    kept = list(regular_messages)
    regular_chars = sum(len(str(m.content)) for m in kept)
    while len(kept) > 1 and system_chars + regular_chars > char_limit:
        dropped = kept.pop(0)
        regular_chars -= len(str(dropped.content))
        print(f"Removing message: {dropped.type} - {str(dropped.content)[:50]}...")
    return kept


def call_llm_with_memory_management(state: AgentState, llm_model) -> AgentState:
    """Call the LLM on a truncated view of the history and append its reply.

    The context sent to the model is truncated (by message count, then by
    character count), but the returned state always carries the FULL original
    history plus the new AI message.

    Args:
        state (AgentState): Current agent state; ``state["messages"]`` is read.
        llm_model: Object exposing ``invoke(prompt: str)``. The result may be a
            message object, an object with a ``content`` attribute, or a string.

    Returns:
        AgentState: A shallow copy of ``state`` whose ``messages`` end with the
        LLM reply. On failure an error ``AIMessage`` is appended instead and
        ``done`` is set to True so the graph does not loop on a broken LLM.
    """
    print("Running call_llm with memory management...")

    # Work on copies: truncation must never shrink the stored history.
    original_messages = list(state["messages"])
    candidate_messages = list(original_messages)

    # --- Context truncation ---
    # A leading system message is preserved and excluded from the message count.
    system_message = None
    if candidate_messages and isinstance(candidate_messages[0], SystemMessage):
        system_message = candidate_messages[0]
        regular_messages = candidate_messages[1:]
    else:
        regular_messages = candidate_messages

    max_regular_messages = 9
    if len(regular_messages) > max_regular_messages:
        # Keeps one fewer than the cap, matching the previous slicing behavior.
        regular_messages = regular_messages[-(max_regular_messages - 1):]
        kept_total = len(regular_messages) + (1 if system_message is not None else 0)
        print(f"🔄 Truncating message count: {len(candidate_messages)} -> {kept_total} messages")

    def _assemble(regulars):
        return ([system_message] if system_message is not None else []) + list(regulars)

    messages_for_llm = _assemble(regular_messages)

    # Character count is a rough proxy for tokens (about 8k chars ~ 4k tokens).
    char_limit = 8000
    total_chars = sum(len(str(msg.content)) for msg in messages_for_llm)
    if total_chars > char_limit:
        print(f"📏 Context too long ({total_chars} chars > {char_limit}), further truncation needed")
        regular_messages = _trim_to_char_budget(system_message, regular_messages, char_limit)
        messages_for_llm = _assemble(regular_messages)
        print(f"Context truncated to {sum(len(str(m.content)) for m in messages_for_llm)} chars.")

    new_state = state.copy()
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            print(f"🧹 Pre-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")

        print(f"Invoking LLM with {len(messages_for_llm)} messages.")
        # The model is prompted with one flat string, one "[TYPE] content" line per message.
        formatted_input = "\n".join([f"[{msg.type.upper()}] {msg.content}" for msg in messages_for_llm])
        print(f"\n\nFormatted input for LLM:\n\n{formatted_input}")
        llm_response_object = llm_model.invoke(formatted_input)

        # Normalize whatever the model returned into a message object.
        if isinstance(llm_response_object, BaseMessage):
            ai_message_response = llm_response_object
            if not ai_message_response.content:
                ai_message_response.content = ""
        elif hasattr(llm_response_object, 'content'):
            raw_content = llm_response_object.content
            ai_message_response = AIMessage(content=str(raw_content) if raw_content is not None else "")
        else:
            ai_message_response = AIMessage(
                content=str(llm_response_object) if llm_response_object is not None else ""
            )

        print(f"LLM Response: {ai_message_response.content[:300]}...")

        # The reply is appended to the full history, not the truncated view.
        new_state["messages"] = original_messages + [ai_message_response]
        new_state.pop("done", None)  # LLM responded, so not 'done' by default
    except Exception as e:
        print(f"LLM call failed: {e}")
        error_message_content = (
            f"LLM call failed with error: {str(e)}. Input consisted of {len(messages_for_llm)} messages."
        )
        if "out of memory" in str(e).lower():
            print("🚨 CUDA OOM detected during LLM call! Implementing emergency cleanup...")
            error_message_content = f"LLM failed due to Out of Memory: {str(e)}."
            try:
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                gc.collect()
            except Exception as cleanup_e:
                print(f"Emergency OOM cleanup failed: {cleanup_e}")
        new_state["messages"] = original_messages + [AIMessage(content=error_message_content)]
        new_state["done"] = True  # Mark as done to prevent loops on LLM failure
    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                print(f"🧹 Post-LLM CUDA cache cleared. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
        except Exception:
            pass  # Best effort: a cleanup failure must not hide the main error
    return new_state
| import re | |
| import uuid | |
def parse_react_output(state: AgentState) -> AgentState:
    """Parse the last AI message for a ReAct ``FINAL ANSWER`` or tool call.

    Whichever marker ("FINAL ANSWER:" or "Action Input:") occurs LAST in the
    message decides the outcome:

    * FINAL ANSWER: the message is replaced by a clean
      ``FINAL ANSWER: <text>`` message and ``done`` is set.
    * Action Input: the tool name is taken from the last ``Action:`` line that
      precedes that Action Input, and the message is replaced by one carrying
      a single tool call ``{"name", "args": {"query": ...}, "id"}``.
    * neither marker, or an unparseable action: ``done`` is set.

    Args:
        state (AgentState): Current state; ``state["messages"]`` is read.

    Returns:
        AgentState: A shallow copy of ``state`` with the updates above. It is
        returned unchanged when there are no messages or the last message is
        not an ``AIMessage``.
    """
    print("Running parse_react_output (Action prioritized)...")
    messages = state["messages"]
    new_state = state.copy()

    if not messages:
        return new_state
    last_message = messages[-1]
    # Only process AI messages (not system/user)
    if not isinstance(last_message, AIMessage):
        return new_state

    content = last_message.content
    text = content if isinstance(content, str) else str(content)

    # Strip an echoed system prompt paragraph so its wording cannot be
    # mistaken for model output.
    sys_prompt_pattern = r"(You are a general AI assistant.*?)(?=\n\n|$)"
    text = re.sub(sys_prompt_pattern, '', text, flags=re.DOTALL | re.IGNORECASE).strip()

    final_answer_hits = list(re.finditer(r"FINAL ANSWER:", text, re.IGNORECASE))
    action_input_hits = list(re.finditer(r"Action Input:", text, re.IGNORECASE))

    # Determine which marker occurs last.
    last_marker = None
    last_pos = -1
    if final_answer_hits:
        last_marker = 'FINAL ANSWER'
        last_pos = final_answer_hits[-1].start()
    if action_input_hits and action_input_hits[-1].start() > last_pos:
        last_marker = 'Action Input'
        last_pos = action_input_hits[-1].start()

    if not last_marker:
        print("No FINAL ANSWER or Action Input found in last AI output.")
        new_state["done"] = True
        return new_state

    last_section = text[last_pos:].strip()

    if last_marker == 'FINAL ANSWER':
        answer = re.search(r"FINAL ANSWER:\s*(.+)", last_section, re.IGNORECASE)
        final_answer_text = answer.group(1).strip() if answer else ""
        updated_ai_message = AIMessage(content=f"FINAL ANSWER: {final_answer_text}", tool_calls=[])
        new_state["messages"] = messages[:-1] + [updated_ai_message]
        new_state["done"] = True
        print(f"FINAL ANSWER found at end: '{final_answer_text}'")
        return new_state

    # last_marker == 'Action Input'.
    # The "Action:" line sits BEFORE the "Action Input:" marker, so it has to
    # be searched for in the text preceding the marker, not in last_section
    # (which starts at the marker itself).
    action_hits = list(re.finditer(r"Action:[ \t]*([^\n]+)", text[:last_pos], re.IGNORECASE))
    input_hit = re.match(r"Action Input:\s*([^\n]+)", last_section, re.IGNORECASE)

    if action_hits and input_hit:
        tool_name = action_hits[-1].group(1).strip()
        tool_input_raw = input_hit.group(1).strip()
        print(f"ReAct: Found Action: {tool_name}, Input: '{tool_input_raw}'")
        tool_args = {"query": tool_input_raw}
        tool_call_id = str(uuid.uuid4())
        parsed_tool_calls = [{"name": tool_name, "args": tool_args, "id": tool_call_id}]
        updated_ai_message = AIMessage(content=content, tool_calls=parsed_tool_calls)
        new_state["messages"] = messages[:-1] + [updated_ai_message]
        new_state.pop("done", None)
        print(f"AIMessage updated with tool_calls: {parsed_tool_calls}")
        return new_state

    print("Action Input found at end, but could not parse Action or Action Input.")
    new_state["done"] = True
    return new_state
def download_youtube_video(url, output_dir='/tmp/video/', output_filename='downloaded_video.mp4'):
    """Fetch a YouTube video with yt-dlp into a freshly emptied directory.

    Args:
        url: URL of the video to download.
        output_dir: Directory that receives the file. It is created if missing
            and every entry matched by ``*`` inside it is removed first.
        output_filename: File name used for the downloaded video.

    Returns:
        The joined output path when yt-dlp finishes without raising, otherwise
        None. The path is not checked for existence here.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Remove leftovers from earlier downloads; failures are only reported.
    for stale_path in glob.glob(os.path.join(output_dir, '*')):
        try:
            os.remove(stale_path)
        except Exception as exc:
            print(f"Error deleting {stale_path}: {str(exc)}")

    target_path = os.path.join(output_dir, output_filename)

    # Prefer separate mp4 video + m4a audio, then fall back to single-file formats.
    recode_step = {
        'key': 'FFmpegVideoConvertor',
        'preferedformat': 'mp4',
    }
    downloader_options = {
        'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
        'outtmpl': target_path,
        'quiet': True,
        'merge_output_format': 'mp4',
        'postprocessors': [recode_step],
    }

    try:
        with yt_dlp.YoutubeDL(downloader_options) as downloader:
            downloader.download([url])
    except Exception as exc:
        print(f"Error downloading YouTube video: {str(exc)}")
        return None
    return target_path
def extract_frames(video_path, output_dir, frame_interval_seconds=10):
    """Save one JPEG frame every ``frame_interval_seconds`` of a video.

    The output directory is emptied (files, links and sub-directories) before
    extraction, or created when it does not exist. Frames are written as
    ``frame_<index>.jpg`` where ``<index>`` is the zero-padded frame counter.

    Args:
        video_path: Path of the video file to read.
        output_dir: Directory that receives the extracted frames.
        frame_interval_seconds: Seconds between two saved frames.

    Returns:
        bool: True when at least one frame was written, False otherwise
        (including when the video cannot be opened or an error occurs).
    """
    # Clean output directory before extracting new frames
    if os.path.exists(output_dir):
        for filename in os.listdir(output_dir):
            file_path = os.path.join(output_dir, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.unlink(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f'Failed to delete {file_path}. Reason: {e}')
    else:
        os.makedirs(output_dir, exist_ok=True)

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print("Error: Could not open video.")
            return False

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            # Some containers report 0 (or NaN) FPS; assume a common frame
            # rate rather than ending up with a zero-length interval.
            fps = 30.0
            print(f"Warning: video reports no usable FPS; assuming {fps}.")
        # Clamp to 1 so the modulo below can never divide by zero.
        frame_interval = max(1, int(fps * frame_interval_seconds))

        count = 0
        saved = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if count % frame_interval == 0:
                frame_filename = os.path.join(output_dir, f"frame_{count:06d}.jpg")
                # imwrite reports failure through its return value, not an exception.
                if cv2.imwrite(frame_filename, frame):
                    saved += 1
                else:
                    print(f"Warning: could not write {frame_filename}")
            count += 1
        print(f"Extracted {saved} frames.")
        return saved > 0
    except Exception as e:
        print(f"Exception during frame extraction: {e}")
        return False
    finally:
        # Release the capture on every path (early return and errors included).
        if cap is not None:
            cap.release()
_BLIP_VQA_MODEL_NAME = "Salesforce/blip-vqa-base"
_BLIP_VQA_CACHE = {}


def _load_blip_vqa():
    """Return the BLIP VQA ``(processor, model)`` pair, loading it on first use.

    The pair is kept in a module-level cache so a video with many frames does
    not reload the model for every frame. The model therefore stays resident
    in CPU memory for the lifetime of the process.
    """
    if "pair" not in _BLIP_VQA_CACHE:
        processor = BlipProcessor.from_pretrained(_BLIP_VQA_MODEL_NAME)
        model = BlipForQuestionAnswering.from_pretrained(_BLIP_VQA_MODEL_NAME).to('cpu')
        # Stored only after both loads succeeded, so a failure leaves no half-filled cache.
        _BLIP_VQA_CACHE["pair"] = (processor, model)
    return _BLIP_VQA_CACHE["pair"]


def answer_question_on_frame(image_path, question):
    """Answer a question about a single video frame using BLIP.

    Args:
        image_path: Path of the image file to analyse.
        question: Natural-language question about the image.

    Returns:
        str: The decoded model answer, or the literal string
        "Error processing this frame" when anything fails (model loading,
        image reading or generation).
    """
    try:
        processor_vqa, model_vqa = _load_blip_vqa()
        device = "cpu"
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy that stays valid afterwards.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        inputs = processor_vqa(image, question, return_tensors="pt").to(device)
        out = model_vqa.generate(**inputs)
        answer = processor_vqa.decode(out[0], skip_special_tokens=True)
        return answer
    except Exception as e:
        print(f"Error processing frame {image_path}: {str(e)}")
        return "Error processing this frame"
def answer_video_question(frames_dir, question):
    """Answer a question about a video by majority vote over its frames.

    Every ``.jpg``/``.jpeg``/``.png`` file in ``frames_dir`` is passed to
    ``answer_question_on_frame`` in ascending order of the first number in its
    file name. Frames that fail are skipped and do not take part in the vote.

    Args:
        frames_dir: Directory containing the extracted frame images.
        question: Question asked about each frame.

    Returns:
        dict: ``most_common_answer`` (str; an explanatory message when nothing
        could be analysed), ``all_answers`` (list of per-frame answers) and
        ``answer_counts`` (``Counter`` of those answers).
    """
    valid_exts = ('.jpg', '.jpeg', '.png')
    # Sentinel returned by answer_question_on_frame when a frame fails. It must
    # not be counted as a real answer or it could win the vote.
    frame_error_answer = "Error processing this frame"

    def _empty_result(message):
        return {
            "most_common_answer": message,
            "all_answers": [],
            "answer_counts": Counter()
        }

    # isdir (rather than exists) also rejects a path that points at a file.
    if not os.path.isdir(frames_dir):
        return _empty_result("No frames found to analyze.")

    frame_files = [os.path.join(frames_dir, f) for f in os.listdir(frames_dir)
                   if f.lower().endswith(valid_exts)]

    # Sort frames by the number embedded in the file name, not lexicographically.
    def get_frame_number(filename):
        match = re.search(r'(\d+)', os.path.basename(filename))
        return int(match.group(1)) if match else 0

    frame_files = sorted(frame_files, key=get_frame_number)
    if not frame_files:
        return _empty_result("No valid image frames found.")

    answers = []
    for frame_path in frame_files:
        try:
            ans = answer_question_on_frame(frame_path, question)
        except Exception as e:
            print(f"Error processing frame {frame_path}: {str(e)}")
            continue
        if ans == frame_error_answer:
            print(f"Skipping frame {os.path.basename(frame_path)}: analysis failed")
            continue
        answers.append(ans)
        print(f"Processed frame: {os.path.basename(frame_path)}, Answer: {ans}")

    if not answers:
        return _empty_result("Could not analyze any frames successfully.")

    counted = Counter(answers)
    most_common_answer = counted.most_common(1)[0][0]
    return {
        "most_common_answer": most_common_answer,
        "all_answers": answers,
        "answer_counts": counted
    }
def _parse_youtube_tool_input(input_data):
    """Turn the tool input into a dict, or return None when that is impossible.

    A ReAct-style agent hands the Action Input over as text, so besides a real
    dict this accepts a JSON object or a Python dict literal (the form shown
    in the tool description) encoded as a string.
    """
    import ast
    import json

    if isinstance(input_data, dict):
        return input_data
    if not isinstance(input_data, str):
        return None
    text = input_data.strip()
    # literal_eval only evaluates literals, so it is safe on model output.
    for parser in (json.loads, ast.literal_eval):
        try:
            parsed = parser(text)
        except (ValueError, SyntaxError, TypeError):
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


class YoutubeScreenshotQA(BaseTool):
    """Tool that answers a question about a YouTube video from its screenshots.

    Pipeline: download the video, extract one frame every
    ``frame_interval_seconds`` seconds, run visual question answering on each
    frame and report the most common answer together with the full tally.
    """

    name: str = "youtube_screenshot_qa"
    description: str = (
        "Downloads a YouTube video, extracts screenshots at intervals, "
        "and answers a question about the video based on the screenshots. "
        "Input should be a dict with keys: 'youtube_url' and 'question'."
        "Example input: {'youtube_url': 'https://www.youtube.com/watch?v=L1vXCYZAYYM', 'question': 'What is the highest number of bird species on camera simultaneously?'}"
    )
    # Seconds between two extracted screenshots.
    frame_interval_seconds: int = 10

    def _run(self, input_data: Union[Dict[str, Any], str]) -> str:
        """Run the download / extract / answer pipeline.

        Args:
            input_data: Dict with keys ``youtube_url`` and ``question``, or a
                string holding such a dict as JSON or as a Python literal.

        Returns:
            str: A three-line report (most common answer, all answers, answer
            counts), or a message starting with "Error:" when a step fails.
        """
        parsed_input = _parse_youtube_tool_input(input_data)
        if parsed_input is None:
            return "Error: Input must include 'youtube_url' and 'question'."
        youtube_url = parsed_input.get("youtube_url")
        question = parsed_input.get("question")
        if not youtube_url or not question:
            return "Error: Input must include 'youtube_url' and 'question'."

        # Step 1: Download the video
        video_dir = '/tmp/video/'
        video_filename = 'downloaded_video.mp4'
        print(f"Downloading YouTube video from {youtube_url}...")
        video_path = download_youtube_video(youtube_url, output_dir=video_dir, output_filename=video_filename)
        if not video_path or not os.path.exists(video_path):
            return "Error: Failed to download the YouTube video."

        # Step 2: Extract frames
        frames_dir = '/tmp/video_frames/'
        print(f"Extracting frames from {video_path} every {self.frame_interval_seconds} seconds...")
        success = extract_frames(video_path, frames_dir, frame_interval_seconds=self.frame_interval_seconds)
        if not success:
            return "Error: Failed to extract frames from the video."

        # Step 3: Analyze frames and answer question
        print(f"Answering question about the video frames...")
        answer_result = answer_video_question(frames_dir, question)
        if not answer_result or not answer_result.get("most_common_answer"):
            return "Error: Could not analyze video frames to answer the question."

        # Format the result
        most_common = answer_result["most_common_answer"]
        all_answers = answer_result["all_answers"]
        counts = answer_result["answer_counts"]
        result = (
            f"Most common answer: {most_common}\n"
            f"All answers: {all_answers}\n"
            f"Answer counts: {dict(counts)}"
        )
        return result
def tools_condition_with_logging(state: AgentState):
    """Route to the tools node when the last message looks like a tool request.

    The last message's content is matched against the ReAct layout
    (``Thought:`` / ``Action:`` / ``Action Input:``, the Thought part being
    optional). When an action is found, a description of it is appended to
    ``state['parsed_tool_calls']`` (the list is created if absent). If no
    action is found, the mere mention of a known tool keyword is still
    treated as a tool request. Every decision is printed.

    Args:
        state (AgentState): The current state containing messages.

    Returns:
        str: "tools" if a tool request was detected, "__end__" otherwise.
    """
    history = state.get("messages")
    if not history:
        print("❌ No messages found in state, ending conversation")
        return "__end__"

    latest = history[-1]
    if hasattr(latest, 'content'):
        content = str(latest.content)
    elif isinstance(latest, dict) and 'content' in latest:
        content = str(latest['content'])
    else:
        print("❌ No content found in last message, ending conversation")
        return "__end__"

    print(f"🔍 Analyzing message content: {content[:200]}...")

    regex_flags = re.DOTALL | re.IGNORECASE
    thought = action = action_input = None

    # Full layout first, then the variant without a Thought line.
    full_hit = re.search(
        r'Thought:\s*(.*?)\n\s*Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)',
        content,
        regex_flags,
    )
    if full_hit is not None:
        thought, action, action_input = (part.strip() for part in full_hit.groups())
    else:
        short_hit = re.search(
            r'Action:\s*(.*?)\n\s*Action Input:\s*(.*?)(?:\n|$)',
            content,
            regex_flags,
        )
        if short_hit is not None:
            thought = "No thought provided"
            action, action_input = (part.strip() for part in short_hit.groups())

    # An empty action name does not count as a tool call.
    tool_requested = bool(action)

    if tool_requested:
        print("🔧 Found tool call format:")
        print(f" Thought: {thought}")
        print(f" Action: {action}")
        print(f" Action Input: {action_input}")

        # Aliases the model may use, mapped to the tool identifiers.
        alias_table = {
            'wikipedia_semantic_search': 'wikipedia_tool',
            'wikipedia': 'wikipedia_tool',
            'search': 'search_tool',
            'duckduckgo_search': 'search_tool',
            'web_search': 'search_tool',
            'youtube_screenshot_qa_tool': 'youtube_tool',
            'youtube': 'youtube_tool',
        }
        canonical = action.lower().strip()

        # The call is only recorded here; executing it is left to the tools node.
        if 'parsed_tool_calls' not in state:
            state['parsed_tool_calls'] = []
        call_record = {
            'thought': thought,
            'action': action,
            'action_input': action_input,
            'normalized_action': canonical,
            'tool_mapping': alias_table.get(canonical, canonical)
        }
        state['parsed_tool_calls'].append(call_record)
        print(f"🚀 Added tool call to state: {call_record}")
    else:
        # Fallback: a bare mention of a tool keyword anywhere in the content.
        lowered = content.lower()
        keyword_candidates = (
            'wikipedia_semantic_search', 'wikipedia', 'search', 'duckduckgo',
            'youtube_screenshot_qa_tool', 'youtube', 'web search'
        )
        first_keyword = next((kw for kw in keyword_candidates if kw in lowered), None)
        if first_keyword is not None:
            print(f"🔧 Found tool keyword '{first_keyword}' in content (fallback detection)")
            tool_requested = True

    if not tool_requested:
        print("✅ No tool calls found, ending conversation")
        return "__end__"
    print("🔧 Tool calls detected, transitioning to tools...")
    return "tools"
| # 2. Improved call_tool with memory management | |
def _release_cuda_cache(report=False):
    """Best-effort CUDA cache release; never raises.

    Args:
        report: When True, print the allocated memory after clearing.
    """
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            if report:
                print(f"🧹 Cleared CUDA cache. Memory: {torch.cuda.memory_allocated()/1024**2:.1f}MB")
    except Exception as cleanup_error:
        # Cleanup is optional; report it instead of swallowing it silently.
        print(f"CUDA cache cleanup skipped: {cleanup_error}")


def call_tool_with_memory_management(state: AgentState) -> AgentState:
    """Execute the tool calls requested by the last message.

    If ``state['parsed_tool_calls']`` is non-empty, execution is delegated to
    ``execute_parsed_tool_calls``. Otherwise each entry of the last message's
    ``tool_calls`` is run against the global ``tools`` list and one
    ``ToolMessage`` per call is appended. Results are truncated (3000
    characters for Wikipedia tools, 2000 otherwise) to bound memory use.

    Args:
        state (AgentState): Current state; ``state["messages"]`` is read.

    Returns:
        AgentState: The input state itself when there is nothing to execute,
        otherwise a shallow copy whose ``messages`` include the tool results.
    """
    print("Running call_tool with memory management...")
    _release_cuda_cache(report=True)

    # Check if we have parsed tool calls from the condition function
    if 'parsed_tool_calls' in state and state['parsed_tool_calls']:
        return execute_parsed_tool_calls(state)

    # Fallback to original OpenAI-style tool calls handling
    messages = state["messages"]
    if not messages:
        print("No tool calls found in last message")
        return state
    last_message = messages[-1]
    if not hasattr(last_message, "tool_calls") or not last_message.tool_calls:
        print("No tool calls found in last message")
        return state

    # Copy the messages to avoid mutating the original list
    new_messages = list(messages)
    print(f"Processing {len(last_message.tool_calls)} tool calls")

    for i, tool_call in enumerate(last_message.tool_calls):
        # Handle both dict and object-style tool calls
        if isinstance(tool_call, dict):
            tool_name = tool_call.get("name", "")
            args = tool_call.get("args", {})
            tool_call_id = tool_call.get("id", str(uuid.uuid4()))
        else:
            tool_name = getattr(tool_call, "name", "")
            args = getattr(tool_call, "args", {})
            tool_call_id = getattr(tool_call, "id", str(uuid.uuid4()))
        print(f"Processing tool call {i+1}: {tool_name}")

        # Find the matching tool (case-insensitive on the tool name)
        selected_tool = next(
            (tool for tool in tools if tool.name.lower() == tool_name.lower()),
            None,
        )

        if not selected_tool:
            tool_result = f"Error: Tool '{tool_name}' not found. Available tools: {', '.join(t.name for t in tools)}"
            print(f"Tool not found: {tool_name}")
        else:
            try:
                # Extract query
                if isinstance(args, dict) and "query" in args:
                    query = args["query"]
                else:
                    query = str(args) if args else ""
                print(f"Executing {tool_name} with query: {str(query)[:100]}...")
                # Tools may return non-string values; ToolMessage content and
                # the truncation below need text.
                tool_result = str(selected_tool.run(query))

                # Aggressive truncation to prevent memory issues
                max_length = 3000 if "wikipedia" in tool_name.lower() else 2000
                if len(tool_result) > max_length:
                    tool_result = tool_result[:max_length] + f"... [Result truncated from {len(tool_result)} to {max_length} chars to prevent memory issues]"
                    print(f"📄 Truncated result to {max_length} characters")
                print(f"Tool result length: {len(tool_result)} characters")
            except Exception as e:
                tool_result = f"Error executing tool '{tool_name}': {str(e)}"
                print(f"Tool execution error: {e}")

        # Create tool message
        tool_message = ToolMessage(
            content=tool_result,
            name=tool_name,
            tool_call_id=tool_call_id
        )
        new_messages.append(tool_message)
        print(f"Added tool message for {tool_name}")

    # Update the state
    new_state = state.copy()
    new_state["messages"] = new_messages

    _release_cuda_cache()
    return new_state
def execute_parsed_tool_calls(state: AgentState):
    """Execute tool calls parsed from the Thought/Action/Action Input format.

    Each entry of ``state['parsed_tool_calls']`` is resolved against the
    global ``tools`` list and run with its ``action_input``. The outcome
    (result, execution error, or "tool not found") is appended to the history
    as an ``AIMessage`` whose content starts with "Observation: ".

    A tool is looked up, in order, by: the normalized action name, the alias
    from the local mapping table, and the ``tool_mapping`` stored with the
    parsed call (when present).

    Note: ``state`` is updated IN PLACE (``messages`` is replaced and
    ``parsed_tool_calls`` is reset to an empty list) and the same object is
    returned.

    Args:
        state (AgentState): The current state containing parsed tool calls.

    Returns:
        AgentState: The same state object, with tool results appended.
    """
    # Aliases the model may use, mapped to tool names / class-based aliases.
    tool_name_mappings = {
        'wikipedia_semantic_search': 'wikipedia_tool',
        'wikipedia': 'wikipedia_tool',
        'search': 'enhanced_search',
        'duckduckgo_search': 'enhanced_search',
        'web_search': 'enhanced_search',
        'enhanced_search': 'enhanced_search',
        'youtube_screenshot_qa_tool': 'youtube_tool',
        'youtube': 'youtube_tool',
    }

    # Index the available tools by name, plus a generic alias derived from
    # the class name for flexibility.
    tools_by_name = {}
    for tool in tools:
        tools_by_name[tool.name.lower()] = tool
        class_name = tool.__class__.__name__.lower()
        if 'wikipedia' in class_name:
            tools_by_name['wikipedia_tool'] = tool
        elif 'search' in class_name or 'duck' in class_name:
            tools_by_name['search_tool'] = tool
        elif 'youtube' in class_name:
            tools_by_name['youtube_tool'] = tool

    # Copy messages to avoid mutation during iteration
    new_messages = list(state["messages"])

    for tool_call in state['parsed_tool_calls']:
        action = tool_call['action']
        action_input = tool_call['action_input']
        normalized_action = tool_call['normalized_action']
        print(f"🚀 Executing tool: {action} with input: {action_input}")

        # The stored 'tool_mapping' is the last resort: it makes class-based
        # aliases such as 'search_tool' reachable when the mapped name
        # ('enhanced_search') is not a registered tool name.
        candidates = (
            normalized_action,
            tool_name_mappings.get(normalized_action),
            tool_call.get('tool_mapping'),
        )
        tool_instance = next(
            (tools_by_name[name] for name in candidates if name in tools_by_name),
            None,
        )

        if tool_instance is None:
            print(f"❌ Tool '{action}' not found in available tools")
            available_tool_names = list(tools_by_name.keys())
            observation = f"Observation: Tool '{action}' not found. Available tools: {', '.join(available_tool_names)}"
        else:
            try:
                # Tools may return non-string values; the truncation needs text.
                result = str(tool_instance.run(action_input))
                if len(result) > 6000:
                    result = result[:6000] + "... [Result truncated due to length]"
                observation = f"Observation: {result}"
                print(f"✅ Tool '{action}' executed successfully")
            except Exception as e:
                print(f"❌ Error executing tool '{action}': {e}")
                observation = f"Observation: Error executing '{action}': {str(e)}"

        new_messages.append(AIMessage(content=observation))

    # Update state with new messages and clear parsed tool calls
    state["messages"] = new_messages
    state['parsed_tool_calls'] = []
    return state
| # 1. Loop detection: decide whether the agent should keep running or stop | |
def should_continue(state: AgentState) -> str:
    """Decide whether the agent loop goes on ("continue") or stops ("end").

    The loop stops when any of these holds, checked in this order:
    the state is flagged ``done``; at least 3 messages carry tool calls; the
    two most recent dict-style tool calls within the last 6 messages have the
    same name and arguments; the history holds more than 15 messages.

    Args:
        state (AgentState): Current state; ``messages`` and ``done`` are read.

    Returns:
        str: "end" or "continue".
    """
    print("Running should_continue....")
    history = state["messages"]

    if state.get("done", False):
        return "end"

    # Cap on the number of messages that requested tools.
    calls_so_far = len([m for m in history if getattr(m, 'tool_calls', None)])
    if calls_so_far >= 3:
        print(f"⚠️ Stopping: Too many tool calls ({calls_so_far})")
        return "end"

    # Collect (name, args) signatures of recent calls to spot an exact repeat.
    signatures = []
    for recent in history[-6:]:
        requested = getattr(recent, 'tool_calls', None)
        if not requested:
            continue
        signatures.extend(
            (call.get('name'), str(call.get('args', {})))
            for call in requested
            if isinstance(call, dict)
        )
    if len(signatures) > 1 and signatures[-2] == signatures[-1]:
        print("⚠️ Stopping: Repeated tool call detected")
        return "end"

    # Guard against runaway conversations.
    if len(history) > 15:
        print(f"⚠️ Stopping: Too many messages ({len(history)})")
        return "end"

    return "continue"
def route_after_parse_react(state: AgentState) -> str:
    """Pick the node that follows ReAct parsing.

    Args:
        state (AgentState): Current state; ``done`` and ``messages`` are read.

    Returns:
        str: "end_processing" when the state is flagged ``done``; "call_tool"
        when the last message carries a non-empty ``tool_calls``; otherwise
        "call_llm".
    """
    # The done flag wins over everything else.
    if state.get("done", False):
        return "end_processing"

    history = state.get("messages", [])
    pending_calls = getattr(history[-1], 'tool_calls', None) if history else None
    return "call_tool" if pending_calls else "call_llm"
| #wikipedia_tool = WikipediaSearchToolWithFAISS() | |
| #search_tool = DuckDuckGoSearchRun() | |
| #youtube_screenshot_qa_tool = YoutubeScreenshotQA() | |
| # Combine all tools | |
| #tools = [wikipedia_tool, search_tool, youtube_screenshot_qa_tool] | |
# The tool instances are created as module globals inside run_agent().
# --- Graph Construction ---
def create_memory_safe_workflow():
    """Build and compile the agent graph.

    Graph layout:

    * ``call_llm`` (entry point) -> ``should_continue`` routes to
      ``parse_react_output`` on ``"continue"`` or to ``END`` on ``"end"``.
    * ``parse_react_output`` -> ``route_after_parse_react`` routes to
      ``call_tool``, back to ``call_llm``, or to ``END``.
    * ``call_tool`` -> ``call_llm`` (unconditional).

    The LLM pipeline is created here and bound to the ``call_llm`` node, so
    it lives as long as the compiled graph does.

    Returns:
        The compiled graph returned by ``StateGraph.compile()``.
    """
    # Imported locally so the function does not depend on a module-level
    # `functools` import, which is not present in the file's import block.
    from functools import partial

    hf_pipe = create_llm_pipeline()
    llm = HuggingFacePipeline(pipeline=hf_pipe)

    workflow = StateGraph(AgentState)

    # The LLM node receives only the state from the graph, so the model is
    # pre-bound through the `llm_model` keyword argument.
    bound_call_llm = partial(call_llm_with_memory_management, llm_model=llm)

    workflow.add_node("call_llm", bound_call_llm)
    workflow.add_node("parse_react_output", parse_react_output)
    workflow.add_node("call_tool", call_tool_with_memory_management)

    workflow.set_entry_point("call_llm")

    # After the LLM responds: either parse its output or stop (loop guards).
    workflow.add_conditional_edges(
        "call_llm",
        should_continue,
        {
            "continue": "parse_react_output",
            "end": END,
        },
    )

    # After parsing: run a tool, ask the LLM again, or finish.
    workflow.add_conditional_edges(
        "parse_react_output",
        route_after_parse_react,
        {
            "call_tool": "call_tool",
            "call_llm": "call_llm",
            "end_processing": END,
        },
    )

    # Tool output always goes back to the LLM.
    workflow.add_edge("call_tool", "call_llm")

    return workflow.compile()
| # --- Run the Agent --- | |
def _extract_youtube_url(query: str) -> Optional[str]:
    """Return the YouTube URL mentioned in *query*, or None if there is none."""
    yt_pattern = r"(https?://)?(www\.)?(youtube\.com|youtu\.be)/[^\s]+"
    # Punctuation that belongs to the surrounding sentence, not to the URL
    # (e.g. "... watch?v=abc, what is ..." or "(see https://youtu.be/abc)").
    trailing_punctuation = ".,;:!?)]}>\"'"

    # Prefer a full http(s) URL, but only one that actually points at YouTube;
    # the question may contain other, unrelated links before it.
    for candidate in re.findall(r"https?://[^\s]+", query):
        if re.search(yt_pattern, candidate):
            return candidate.rstrip(trailing_punctuation)

    # Fall back to a scheme-less mention such as "youtube.com/watch?v=abc".
    bare_match = re.search(yt_pattern, query)
    if bare_match:
        return bare_match.group(0).rstrip(trailing_punctuation)
    return None


def run_agent(myagent, state: AgentState):
    """Run one question through the compiled agent graph.

    The function (re)creates the tool instances as module globals, builds the
    system prompt from their descriptions, seeds ``state`` with a system and a
    human message, and invokes the graph.

    Args:
        myagent: Compiled graph exposing ``invoke(state)``.
        state: Agent state; ``state["question"]`` must hold the user question.
            The dict is modified in place: ``messages`` and ``done`` are
            reset, and ``youtube_url`` is set when the question contains a
            YouTube link.

    Returns:
        The ``messages`` list from the graph result.

    Raises:
        KeyError: If ``state`` has no ``"question"`` entry.
    """
    global WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL, tools

    # Fresh tool instances for every run, published as module globals.
    WIKIPEDIA_TOOL = WikipediaSearchToolWithFAISS()
    SEARCH_TOOL = EnhancedDuckDuckGoSearchTool(max_results=3, max_chars_per_page=3000)
    YOUTUBE_TOOL = YoutubeScreenshotQA()
    tools = [WIKIPEDIA_TOOL, SEARCH_TOOL, YOUTUBE_TOOL]

    # Create a fresh system message each time
    formatted_tools_description = render_text_description(tools)
    system_content = f"""You are a general AI assistant. with access to these tools:
{formatted_tools_description}
If you need the most current information as of 2025, use enhanced_search
If you need to do in-depth research, use wikipedia_semantic_search_all_candidates_strong_entity_priority_list_retrieval
If you can answer the question confidently, do so directly.
If the question seems like gibberish (not English), try flipping the entire question and re-read the question.
If you need more information, use a tool.
(Think through the problem step by step)
When using a tool, follow this format:
Thought: <your thought>
Action: <tool_name>
Action Input: <tool_input>
Only use the tools listed above for the Action: step. Do not invent new tool names or actions. If you need to reason, do so in the Thought: step. After using a tool, process its output in your Thought: step, not as an Action.
Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string
Do not provide disclaimers.
Do not provide supporting details.
"""

    # Get user question from AgentState
    query = state['question']

    # Remember the YouTube link (if any) in the state.
    youtube_url = _extract_youtube_url(query)
    if youtube_url is not None:
        state['youtube_url'] = youtube_url

    # Start every run from a clean, properly typed message history.
    state['messages'] = [
        SystemMessage(content=system_content),
        HumanMessage(content=query),
    ]
    state["done"] = False

    result = myagent.invoke(state)

    # Check if FINAL ANSWER was given (i.e., workflow ended)
    if result.get("done"):
        # Collect Python garbage first so tensors that are no longer
        # referenced are released before the CUDA cache is emptied.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        print("Released GPU memory after FINAL ANSWER.")

    # Extract and return just the messages for cleaner output
    return result["messages"]