File size: 9,339 Bytes
655b11f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import os
from io import StringIO
import re
import base64
from langchain_core.tools import tool
from langchain_tavily import TavilySearch
from langchain_experimental.utilities import PythonREPL
from langchain_community.retrievers import WikipediaRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import HumanMessage
from typing import List
import wikipedia
from bs4 import BeautifulSoup, Tag
import json
import pandas as pd
from logging_config import logger  # Import the shared logger
from dotenv import load_dotenv
load_dotenv()

@tool
def python_tool(code: str) -> str:
    """A Python shell. Use this to execute python commands. 
    Input should be an str with a valid python script. 
    If you want to see the output of a value, 
    you should print it out with `print(...)`."""

    # Fixed log format: the original f-string ran "tool" and the code together.
    logger.info(f"Invoking Python REPL tool with code: {code!r}")
    repl = PythonREPL()
    try:
        result = repl.run(code)
    except BaseException as e:
        # BaseException on purpose: agent-generated code may raise SystemExit
        # or similar; report it back as text instead of killing the process.
        return f"Failed to execute. Error: {e!r}"
    return f"Result of code execution: {result}"

@tool
def reverse_tool(question: str) -> str:
    """Reverses the input string."""
    logger.info(f"Invoking reverse tool with question: {question!r}")
    reversed_question = "".join(reversed(question))
    return reversed_question

@tool
def excel_file_to_markdown(task_id: str) -> str:
    """Given a task_id corresponding to an Excel file,
    fetch the file and convert its content to markdown format.

    Args:
        task_id (str): Identifier of the downloaded file; the spreadsheet is
            expected at ./files/<task_id>.xlsx.
    Returns:
        str: The first sheet of the workbook rendered as a markdown table.
    """
    # pandas is already imported at module level as `pd`; the former
    # function-local `import pandas as pd` was redundant.
    path = f"./files/{task_id}.xlsx"
    df = pd.read_excel(path)
    logger.info(f"Converted Excel file {path} to markdown")
    return df.to_markdown()

@tool
def sum_numbers(all_numbers: List[float]) -> float:
    """
    Sums a list of numbers and returns the result.

    Args:
        all_numbers ('list' of float): A list of numbers.

    """
    logger.info(f"Summing numbers: {all_numbers}")
    # Coerce each entry to float lazily, then let sum() do the accumulation.
    as_floats = (float(value) for value in all_numbers)
    return sum(as_floats)

@tool
def web_search(question: str) -> str: # Tool expects arguments, not the whole state
    """Perform a web search using TavilySearch and return relevant documents.
    Args:
        question (str): The query for the web search.
    Returns:
        web_search_result (str): The result of the web search.
    """
    logger.info(f"Performing web search for query: {question}")

    tavily = TavilySearch(
        chunks_per_source=3,
        max_results=3,
        include_answer=True,
        include_raw_content="markdown",
        search_depth="advanced",
    )
    try:
        response = tavily.invoke(question)
        results = response.get('results', [])
        logger.info(f"Web search completed with {len(results)} results")

        # Prefer Tavily's synthesized answer when one is present.
        answer = response.get('answer')
        if answer:
            logger.info(f"Web search answer length: {len(answer)}")
            return answer  # type: ignore

        # Otherwise return the raw documents as a JSON string.
        documents = [
            {"url": item.get('url', ""), "content": item.get('content', "")}
            for item in results
        ]
        return json.dumps(documents, indent=2)  # type: ignore

    except Exception as e:
        logger.error(f"Web search failed: {e}")
        # Return an empty list or specific error document if the search fails
        return f"Web search failed: {e}"

# NOTE(review): this tool may be redundant with get_wikipedia_info for the assignment — confirm before removing.
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for query and return maximum 2 results
    
    Args:
        query (str): query to search on Wikipedia
    Returns:
        wiki_result (str): result of search
    """
    try:
        retriever = WikipediaRetriever(top_k_results=2, doc_content_chars_max=20000) # type: ignore
        docs = retriever.invoke(query)
        # One bullet per retrieved document, with its source URL appended.
        wiki_result = "\n".join(
            f"- {doc.page_content} (source: {doc.metadata.get('source', 'unknown')})"
            for doc in docs
        )
        # (Removed an unused local that extracted the first doc's URL but was
        # never read.)
        logger.info(f"Wikipedia search completed for query: {query} with length {len(wiki_result)}")
        return wiki_result
    except Exception as e:
        # Surface failures to the agent as text instead of raising, matching
        # the error-handling style of the other tools in this module.
        return f"wiki_search failed {e}"

@tool
def get_wikipedia_info(query: str) -> str:
    """
    Fetches and parses all HTML tables and their preceding Hx headers
    from a given Wikipedia page. 
    Use this to get structured data from Wikipedia pages, such as lists of items, 
    tables of statistics, discographies, etc.
    Args:
        query (str): The query to search on Wikipedia.    
    Returns:
        formatted_output (str): a string representation of the structured data,
        formatted in a Markdown-like style.
    """
    logger.info(f"Tool get_wikipedia_info invoked with query: {query!r}")
    try:
        # Resolve the query to the top search hit, then fetch that page's HTML.
        page_title = wikipedia.search(query, results=1)[0]
        page_content = wikipedia.page(page_title, auto_suggest=False).html()
        logger.info(f"Fetching Wikipedia page for title: {page_title!r}")
        soup = BeautifulSoup(page_content, 'html.parser')

        # Compile a regular expression for h1 to h6 tags
        heading_pattern = re.compile('^h[1-6]$')

        # Find all headings and tables in one document-order pass so each
        # table can be associated with the headings that precede it.
        elements = soup.find_all([heading_pattern, 'table'])

        extracted_data = []
        current_headers = {}  # maps 'h1'..'h6' -> most recent heading text at that level

        for element in elements:
            if isinstance(element, Tag):
                if heading_pattern.match(element.name):
                    current_headers[element.name] = element.get_text().strip()
                    # Reset lower-level headers when a higher-level one is found
                    for i in range(int(element.name[1]) + 1, 7):
                        current_headers.pop(f'h{i}', None)
                elif element.name == 'table' and 'wikitable' in element.get('class', []): # type: ignore
                    try:
                        df = pd.read_html(StringIO(str(element)))[0] # type: ignore
                        table_info = {
                            'headers': current_headers.copy(),
                            'table_data': df.to_markdown()
                        }
                        extracted_data.append(table_info)
                    except ValueError:
                        # pd.read_html raises ValueError when it cannot parse
                        # a table out of the markup; skip such tables.
                        continue

        if not extracted_data:
            return "No 'wikitable' found on the specified page."

        # Format the extracted data into a readable, markdown string
        formatted_output = "### Extracted Tables with Headers\n\n"

        for i, item in enumerate(extracted_data):
            formatted_output += f"--- Table {i+1} ---\n"

            # Sort headers by level (h1, h2, h3...) to ensure correct order
            sorted_headers = sorted(item['headers'].items(), key=lambda x: int(x[0][1]))

            for header_tag, header_text in sorted_headers:
                header_level = len(header_tag)
                formatted_output += f"{'#' * (header_level + 2)} {header_text}\n"

            formatted_output += f"```\n{item['table_data']}\n```\n\n"

        return formatted_output

    except wikipedia.exceptions.PageError:
        return "Wikipedia page not found."
    except Exception as e:
        return f"An error occurred: {e}"


@tool
def ask_audio_model(query: str, task_id: str) -> str:
    """
    Processes an audio query by sending both a text prompt and an task_id 
    (associated with an audio file)
    to a generative AI model, and returns the model's response.

    Args:
        query (str): The text prompt or question for the model.
        task_id (str): The identifier used to load the audio file (MP3) in the downloaded files directory.

    Returns:
        str: The response generated by the AI model based on the provided text and audio.
    """

    logger.info(f"audio_model called with query='{query[:30]}...'")

    # The Gemini client reads GOOGLE_API_KEY; fall back to GEMINI_API_KEY when
    # only that variant is set. Raises KeyError if neither variable exists.
    if "GOOGLE_API_KEY" not in os.environ:
        os.environ["GOOGLE_API_KEY"] = os.environ["GEMINI_API_KEY"]

    llm = ChatGoogleGenerativeAI(
        model="gemini-2.5-flash-lite-preview-06-17",
        temperature=0,  # deterministic output for tool use
        max_tokens=None,
        timeout=60,  # Added a timeout
        max_retries=2,
    )

    audio_file_path = f"./files/{task_id}.mp3" # Assuming MP3 for a general use case

    # NOTE(review): mime type is hard-coded; assumes every audio file under
    # ./files is MP3 — confirm against the download step.
    audio_mime_type = "audio/mpeg"

    # Read the whole file and base64-encode it for inline transmission.
    with open(audio_file_path, "rb") as audio_file:
        encoded_audio = base64.b64encode(audio_file.read()).decode("utf-8")

    # Multimodal message: a plain text part plus an inline base64 "media" part.
    message = HumanMessage(
        content=[
            {"type": "text", "text": query},
            {
                "type": "media",
                "data": encoded_audio,  # Use base64 string directly
                "mime_type": audio_mime_type,
            },
        ]
    )
    response = llm.invoke([message])  
    logger.info(f"ask_audio_model metadata = {response.usage_metadata}") # type: ignore
    return response.content # type: ignore