Spaces:
Runtime error
Runtime error
File size: 9,339 Bytes
655b11f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 |
import os
from io import StringIO
import re
import base64
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain_experimental.utilities import PythonREPL
from langchain_community.retrievers import WikipediaRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from typing import List
import wikipedia
from bs4 import BeautifulSoup, Tag
import json
import pandas as pd
from logging_config import logger # Import the shared logger
from dotenv import load_dotenv
load_dotenv()
@tool
def python_tool(code: str) -> str:
    """A Python shell. Use this to execute python commands.
    Input should be an str with a valid python script.
    If you want to see the output of a value,
    you should print it out with `print(...)`.

    Args:
        code (str): A valid Python script to execute.

    Returns:
        str: The captured output of the script, or an error description
            if execution failed.
    """
    logger.info(f"Invoking Python REPL tool with code: {code!r}")
    repl = PythonREPL()
    try:
        result = repl.run(code)
    except BaseException as e:
        # BaseException on purpose: user-supplied code may raise SystemExit or
        # KeyboardInterrupt, which a plain `except Exception` would not catch.
        return f"Failed to execute. Error: {e!r}"
    return f"Result of code execution: {result}"
@tool
def reverse_tool(question: str) -> str:
    """Reverses the input string."""
    logger.info(f"Invoking reverse tool with question: {question!r}")
    # Reverse character order without mutating the original string.
    return "".join(reversed(question))
@tool
def excel_file_to_markdown(task_id: str) -> str:
    """Given a task_id corresponding to an Excel file,
    fetch the file and convert its content to markdown format.

    Args:
        task_id (str): Identifier of the downloaded Excel file,
            expected at ./files/<task_id>.xlsx.

    Returns:
        str: The spreadsheet content rendered as a markdown table.
    """
    # pandas is already imported at module level; the redundant
    # function-local `import pandas as pd` was removed.
    path = f"./files/{task_id}.xlsx"
    df = pd.read_excel(path)
    logger.info(f"Converted Excel file {path} to markdown")
    return df.to_markdown()
@tool
def sum_numbers(all_numbers: List[float]) -> float:
    """
    Sums a list of numbers and returns the result.
    Args:
        all_numbers ('list' of float): A list of numbers.
    """
    logger.info(f"Summing numbers: {all_numbers}")
    # Coerce each entry to float (entries may arrive as strings), then total.
    return sum(float(value) for value in all_numbers)
@tool
def web_search(question: str) -> str:  # Tool expects arguments, not the whole state
    """Perform a web search using TavilySearch and return relevant documents.
    Args:
        question (str): The query for the web search.
    Returns:
        web_search_result (str): The result of the web search.
    """
    logger.info(f"Performing web search for query: {question}")
    searcher = TavilySearch(
        chunks_per_source=3,
        max_results=3,
        include_answer=True,
        include_raw_content="markdown",
        search_depth="advanced",
    )
    try:
        response = searcher.invoke(question)
        hits = response.get('results', [])
        logger.info(f"Web search completed with {len(hits)} results")
        # Prefer Tavily's synthesized answer when one is available.
        if response.get('answer'):
            logger.info(f"Web search answer length: {len(response['answer'])}")
            return response['answer']  # type: ignore
        # Otherwise return the raw hits as a JSON list of url/content pairs.
        retrieved_docs = [
            {"url": hit.get('url', ""), "content": hit.get('content', "")}
            for hit in hits
        ]
        return json.dumps(retrieved_docs, indent=2)  # type: ignore
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        # Return an empty list or specific error document if the search fails
        return f"Web search failed: {e}"
# NOTE(review): this tool may not be needed for the assignment — confirm before removing.
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for query and return maximum 2 results
    Args:
        query (str): query to search on Wikipedia
    Returns:
        wiki_result (str): result of search
    """
    try:
        retriever = WikipediaRetriever(top_k_results=2, doc_content_chars_max=20000)  # type: ignore
        docs = retriever.invoke(query)
        # One bullet per document, with its source URL appended.
        # (An unused `url` local that extracted docs[0]'s source was removed.)
        wiki_result = "\n".join(
            f"- {doc.page_content} (source: {doc.metadata.get('source', 'unknown')})"
            for doc in docs
        )
        logger.info(f"Wikipedia search completed for query: {query} with length {len(wiki_result)}")
        return wiki_result  # type: ignore
    except Exception as e:
        # Surface the failure as text so the calling agent can recover.
        return f"wiki_search failed {e}"
@tool
def get_wikipedia_info(query: str) -> str:
    """
    Fetches and parses all HTML tables and their preceding Hx headers
    from a given Wikipedia page.
    Use this to get structured data from Wikipedia pages, such as lists of items,
    tables of statistics, discographies, etc.
    Args:
        query (str): The query to search on Wikipedia.
    Returns:
        formatted_output (str): a string representation of the structured data,
        formatted in a Markdown-like style.
    """
    logger.info(f"Tool get_wikipedia_info invoked with query: {query!r}")
    try:
        # Guard against an empty search: indexing [0] directly would raise
        # IndexError and surface only a generic "An error occurred" message.
        search_hits = wikipedia.search(query, results=1)
        if not search_hits:
            return "Wikipedia page not found."
        page_title = search_hits[0]
        page_content = wikipedia.page(page_title, auto_suggest=False).html()
        logger.info(f"Fetching Wikipedia page for title: {page_title!r}")
        soup = BeautifulSoup(page_content, 'html.parser')
        # Compile a regular expression for h1 to h6 tags
        heading_pattern = re.compile('^h[1-6]$')
        # Find all headings and tables in document order, in one pass
        elements = soup.find_all([heading_pattern, 'table'])
        extracted_data = []
        current_headers = {}  # tag name ('h2', 'h3', ...) -> most recent heading text
        for element in elements:
            if isinstance(element, Tag):
                if re.match(heading_pattern, element.name):
                    current_headers[element.name] = element.get_text().strip()
                    # Reset lower-level headers when a higher-level one is found
                    for i in range(int(element.name[1]) + 1, 7):
                        current_headers.pop(f'h{i}', None)
                elif element.name == 'table' and 'wikitable' in element.get('class', []):  # type: ignore
                    try:
                        df = pd.read_html(StringIO(str(element)))[0]  # type: ignore
                        table_info = {
                            'headers': current_headers.copy(),
                            'table_data': df.to_markdown()
                        }
                        extracted_data.append(table_info)
                    except ValueError:
                        # pd.read_html raises ValueError for unparseable tables.
                        continue
        if not extracted_data:
            return "No 'wikitable' found on the specified page."
        # Format the extracted data into a readable, markdown string
        formatted_output = "### Extracted Tables with Headers\n\n"
        for i, item in enumerate(extracted_data):
            formatted_output += f"--- Table {i+1} ---\n"
            # Sort headers by level (h1, h2, h3...) to ensure correct order
            sorted_headers = sorted(item['headers'].items(), key=lambda x: int(x[0][1]))
            for header_tag, header_text in sorted_headers:
                # Fixed: len(header_tag) is always 2 ('h3' has length 2), which
                # rendered every heading at the same depth. The level is the digit.
                header_level = int(header_tag[1])
                formatted_output += f"{'#' * (header_level + 2)} {header_text}\n"
            formatted_output += f"```\n{item['table_data']}\n```\n\n"
        return formatted_output
    except wikipedia.exceptions.PageError:
        return "Wikipedia page not found."
    except Exception as e:
        return f"An error occurred: {e}"
@tool
def ask_audio_model(query: str, task_id: str) -> str:
    """
    Processes an audio query by sending both a text prompt and an task_id
    (associated with an audio file)
    to a generative AI model, and returns the model's response.
    Args:
        query (str): The text prompt or question for the model.
        task_id (str): The identifier used to load the audio file (MP3) in the downloaded files directory.
    Returns:
        str: The response generated by the AI model based on the provided text and audio.
    """
    logger.info(f"audio_model called with query='{query[:30]}...'")
    # The Gemini client reads GOOGLE_API_KEY; fall back to GEMINI_API_KEY if unset.
    if "GOOGLE_API_KEY" not in os.environ:
        os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"]
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-lite-preview-06-17",
        temperature=0,
        max_tokens=None,
        timeout=60,  # Added a timeout
        max_retries=2,
    )
    audio_file_path = f"./files/{task_id}.mp3"  # Assuming MP3 for a general use case
    audio_mime_type = "audio/mpeg"
    # Read and base64-encode the audio so it can travel inline in the message.
    with open(audio_file_path, "rb") as audio_handle:
        encoded_audio = base64.b64encode(audio_handle.read()).decode("utf-8")
    media_part = {
        "type": "media",
        "data": encoded_audio,  # Use base64 string directly
        "mime_type": audio_mime_type,
    }
    prompt = HumanMessage(content=[{"type": "text", "text": query}, media_part])
    response = llm.invoke([prompt])
    logger.info(f"ask_audio_model metadata = {response.usage_metadata}")  # type: ignore
    return response.content  # type: ignore