# pilgrim-65's picture
# First commit
# 655b11f
import os
from io import StringIO
import re
import base64
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain_experimental.utilities import PythonREPL
from langchain_community.retrievers import WikipediaRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from typing import List
import wikipedia
from bs4 import BeautifulSoup, Tag
import json
import pandas as pd
from logging_config import logger # Import the shared logger
from dotenv import load_dotenv
load_dotenv()
@tool
def python_tool(code: str) -> str:
    """A Python shell. Use this to execute python commands.

    Input should be a str with a valid python script.
    If you want to see the output of a value,
    you should print it out with `print(...)`.

    Args:
        code (str): The Python source to execute.

    Returns:
        str: "Result of code execution: ..." followed by whatever the REPL
        returned, or "Failed to execute. Error: ..." if the run raised.
    """
    # SECURITY: this runs whatever code the caller supplies through
    # PythonREPL. Only expose the tool in an environment where executing
    # untrusted code is acceptable (sandbox / container).
    logger.info(f"Invoking Python REPL tool with code: {code!r}")
    repl = PythonREPL()
    try:
        result = repl.run(code)
    except BaseException as e:
        # NOTE(review): BaseException (not Exception) is kept from the
        # original; presumably so that SystemExit raised by the executed
        # snippet is reported back instead of ending the process. It also
        # swallows KeyboardInterrupt -- confirm that this is intended.
        logger.error(f"Python REPL tool failed: {e!r}")
        return f"Failed to execute. Error: {e!r}"
    return f"Result of code execution: {result}"
@tool
def reverse_tool(question: str) -> str:
    """Return the input string with its characters in reverse order.

    Args:
        question (str): The text to reverse.

    Returns:
        str: The same characters, last to first.
    """
    logger.info(f"Invoking reverse tool with question: {question!r}")
    flipped = "".join(reversed(question))
    return flipped
@tool
def excel_file_to_markdown(task_id: str) -> str:
    """Given a task_id corresponding to an Excel file,
    fetch the file and convert its content to markdown format.

    Args:
        task_id (str): Identifier of the task; the workbook is read from
            ./files/<task_id>.xlsx.

    Returns:
        str: The sheet loaded by `pd.read_excel` (its default sheet) rendered
        as a markdown table, or an "Excel file not found: ..." message when
        the workbook does not exist.
    """
    path = f"./files/{task_id}.xlsx"
    try:
        df = pd.read_excel(path)
    except FileNotFoundError:
        # Report the problem as text, like the other tools in this module,
        # so the calling agent can react instead of crashing.
        logger.error(f"Excel file not found: {path}")
        return f"Excel file not found: {path}"
    markdown = df.to_markdown()
    # Log only once the conversion has actually succeeded.
    logger.info(f"Converted Excel file {path} to markdown")
    return markdown
@tool
def sum_numbers(all_numbers: List[float]) -> float:
    """
    Add up a list of numbers and return the total.

    Args:
        all_numbers ('list' of float): The numbers to add together.
    """
    logger.info(f"Summing numbers: {all_numbers}")
    # Each entry is coerced to float before it is added to the total.
    total = sum(map(float, all_numbers))
    return total
@tool
def web_search(question: str) -> str:  # Tool expects arguments, not the whole state
    """Perform a web search using TavilySearch and return relevant documents.

    Args:
        question (str): The query for the web search.

    Returns:
        web_search_result (str): The answer field of the search payload when
        it is non-empty; otherwise a JSON array of {"url", "content"} objects
        built from the payload's results. On failure, a
        "Web search failed: ..." message.
    """
    logger.info(f"Performing web search for query: {question}")
    try:
        # Constructed inside the try so that a failure while building the
        # tool (e.g. missing credentials) is reported the same way as a
        # failed search instead of escaping as an exception.
        web_tool = TavilySearch(
            chunks_per_source=3,
            max_results=3,
            include_answer=True,
            include_raw_content="markdown",
            search_depth="advanced",
        )
        search_results = web_tool.invoke(question)

        # NOTE(review): some langchain_tavily versions appear to report a
        # failed request as an "error" entry in the payload rather than
        # raising -- confirm against the installed version.
        error = search_results.get('error')
        if error:
            logger.error(f"Web search failed: {error}")
            return f"Web search failed: {error}"

        results = search_results.get('results', [])
        logger.info(f"Web search completed with {len(results)} results")

        # Prefer the ready-made answer when the payload has one.
        answer = search_results.get('answer')
        if answer:
            logger.info(f"Web search answer length: {len(answer)}")
            return answer  # type: ignore

        retrieved_docs = [
            {"url": sr.get('url', ""), "content": sr.get('content', "")}
            for sr in results
        ]
        return json.dumps(retrieved_docs, indent=2)
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        # Hand the failure back as text so the calling agent can react to it.
        return f"Web search failed: {e}"
# This tool is not needed for the assignment???
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for query and return maximum 2 results

    Args:
        query (str): query to search on Wikipedia

    Returns:
        wiki_result (str): one "- <page content> (source: <source>)" entry
        per retrieved document, joined by newlines, or a
        "wiki_search failed ..." message if the lookup raised.
    """
    logger.info(f"Performing Wikipedia search for query: {query}")
    try:
        retriever = WikipediaRetriever(top_k_results=2, doc_content_chars_max=20000)  # type: ignore
        docs = retriever.invoke(query)
        wiki_result = "\n".join(
            f"- {doc.page_content} (source: {doc.metadata.get('source', 'unknown')})"
            for doc in docs
        )
        logger.info(f"Wikipedia search completed for query: {query} with length {len(wiki_result)}")
        return wiki_result  # type: ignore
    except Exception as e:
        # Log before returning so the failure is visible outside the agent.
        logger.error(f"wiki_search failed {e}")
        return f"wiki_search failed {e}"
@tool
def get_wikipedia_info(query: str) -> str:
    """
    Fetches and parses all HTML tables and their preceding Hx headers
    from a given Wikipedia page.
    Use this to get structured data from Wikipedia pages, such as lists of items,
    tables of statistics, discographies, etc.

    Args:
        query (str): The query to search on Wikipedia.

    Returns:
        formatted_output (str): a string representation of the structured data,
        formatted in a Markdown-like style. If nothing can be extracted, a
        short message explaining why (page not found, no 'wikitable' on the
        page, or the error that occurred).
    """
    logger.info(f"Tool get_wikipedia_info invoked with query: {query!r}")
    try:
        titles = wikipedia.search(query, results=1)
        if not titles:
            # An empty search result would otherwise surface as an opaque
            # "list index out of range" error.
            return "Wikipedia page not found."
        page_title = titles[0]
        logger.info(f"Fetching Wikipedia page for title: {page_title!r}")
        page_content = wikipedia.page(page_title, auto_suggest=False).html()
        soup = BeautifulSoup(page_content, 'html.parser')

        # Collect headings (h1-h6) and tables together, in the order
        # find_all returns them, so each table can be paired with the
        # headings seen before it.
        heading_pattern = re.compile('^h[1-6]$')
        elements = soup.find_all([heading_pattern, 'table'])

        extracted_data = []
        current_headers = {}  # heading tag name ('h2', 'h3', ...) -> heading text
        for element in elements:
            if not isinstance(element, Tag):
                continue
            if heading_pattern.match(element.name):
                current_headers[element.name] = element.get_text().strip()
                # Reset lower-level headers when a higher-level one is found
                for i in range(int(element.name[1]) + 1, 7):
                    current_headers.pop(f'h{i}', None)
            elif element.name == 'table' and 'wikitable' in element.get('class', []):  # type: ignore
                try:
                    df = pd.read_html(StringIO(str(element)))[0]  # type: ignore
                    table_markdown = df.to_markdown()
                except ValueError:
                    # Skip tables that cannot be parsed or rendered.
                    continue
                extracted_data.append({
                    'headers': current_headers.copy(),
                    'table_data': table_markdown,
                })

        if not extracted_data:
            return "No 'wikitable' found on the specified page."

        # Format the extracted data into a readable, markdown string
        parts = ["### Extracted Tables with Headers\n\n"]
        for number, item in enumerate(extracted_data, start=1):
            parts.append(f"--- Table {number} ---\n")
            # Sort headers by level (h1, h2, h3...) to ensure correct order
            sorted_headers = sorted(item['headers'].items(), key=lambda kv: int(kv[0][1]))
            for header_tag, header_text in sorted_headers:
                # The level is the digit in the tag name; len(header_tag)
                # would be 2 for every heading and flatten the hierarchy.
                header_level = int(header_tag[1])
                parts.append(f"{'#' * (header_level + 2)} {header_text}\n")
            parts.append(f"```\n{item['table_data']}\n```\n\n")
        return "".join(parts)
    except wikipedia.exceptions.PageError:
        return "Wikipedia page not found."
    except Exception as e:
        logger.error(f"get_wikipedia_info failed: {e}")
        return f"An error occurred: {e}"
@tool
def ask_audio_model(query: str, task_id: str) -> str:
    """
    Processes an audio query by sending both a text prompt and a task_id
    (associated with an audio file)
    to a generative AI model, and returns the model's response.

    Args:
        query (str): The text prompt or question for the model.
        task_id (str): The identifier used to load the audio file (MP3) in the downloaded files directory.

    Returns:
        str: The response generated by the AI model based on the provided text and audio,
        or an "ask_audio_model failed: ..." message when no API key is
        configured or the audio file cannot be read.
    """
    logger.info(f"audio_model called with query='{query[:30]}...'")

    # Fall back to GEMINI_API_KEY when GOOGLE_API_KEY is not set, and give a
    # readable message instead of a bare KeyError when neither is set.
    if "GOOGLE_API_KEY" not in os.environ:
        gemini_key = os.environ.get("GEMINI_API_KEY")
        if not gemini_key:
            failure = "ask_audio_model failed: neither GOOGLE_API_KEY nor GEMINI_API_KEY is set."
            logger.error(failure)
            return failure
        os.environ["GOOGLE_API_KEY"] = gemini_key

    audio_file_path = f"./files/{task_id}.mp3"  # Assuming MP3 for a general use case
    audio_mime_type = "audio/mpeg"
    # Read the audio first so a missing file is reported before any model
    # client is created.
    try:
        with open(audio_file_path, "rb") as audio_file:
            encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")
    except OSError as e:
        failure = f"ask_audio_model failed: could not read {audio_file_path}: {e}"
        logger.error(failure)
        return failure

    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-lite-preview-06-17",
        temperature=0,
        max_tokens=None,
        timeout=60,
        max_retries=2,
    )
    message = HumanMessage(
        content=[
            {"type": "text", "text": query},
            {
                "type": "media",
                "data": encoded_audio,  # Use base64 string directly
                "mime_type": audio_mime_type,
            },
        ]
    )
    response = llm.invoke([message])
    logger.info(f"ask_audio_model metadata = {response.usage_metadata}")  # type: ignore

    content = response.content
    if isinstance(content, str):
        return content
    # Message content may also be a list of parts; keep the declared str
    # return type by concatenating the textual parts.
    return "".join(
        part if isinstance(part, str) else str(part.get("text", ""))
        for part in content
    )