| | import os |
| | import pandas as pd |
| | import tempfile |
| | import typing |
| |
|
| | from base64 import b64encode |
| | from io import StringIO |
| |
|
| | import httpx |
| |
|
| | from anyio import Path |
| | from asyncer import asyncify |
| | from langchain_community.document_loaders import ArxivLoader |
| | from langchain_community.document_loaders import WikipediaLoader |
| | from langchain_core.messages import HumanMessage |
| | from langchain_tavily import TavilyExtract |
| | from langchain_tavily import TavilySearch |
| | from langgraph.prebuilt import create_react_agent |
| | from langgraph.prebuilt import InjectedState |
| | from langchain.tools import BaseTool |
| | from langchain.tools import tool |
| | from pydantic import Field |
| | from typing_extensions import Annotated |
| |
|
| | from utils import get_llm |
| | from config import GOOGLE_API_KEY, AGENT_MODEL_NAME, TAVILY_API_KEY |
| |
|
# System prompt shared by the tool-less "file analyzer" react agents created
# in analyze_image/analyze_excel/analyze_audio/analyze_video below. It
# instructs the model to answer tersely and to prefix every answer with
# 'FINAL ANSWER: '. NOTE: this is runtime text sent to the LLM — edit with care.
MULTIMODAL_FILE_ANALYZER_PROMPT = """
You are a specialized file analysis AI assistant focused on extracting information from various file formats including images, videos, audio, and structured data.
Core Analysis Guidelines:
- Systematic processing: Analyze file contents step by step
- Precise responses: Provide answers in the most concise format - raw numbers, single words, or comma-delimited lists
- Format requirements:
* Numbers: No formatting (no commas, units, or symbols)
* Lists: Pure comma-separated values
* Text: Minimal words, no explanations
- Analysis approach:
* Images: Focus on visual elements, objects, text, and scene composition
* Audio: Identify sounds, speech, music, and audio characteristics
* Video: Analyze visual content, motion, and temporal elements
* Excel/CSV: Extract relevant data points and patterns
- Verification focus: Base answers solely on file contents
- Answer format: Always prefix with 'FINAL ANSWER: '
- Counting tasks: Return only the count
- Listing tasks: Return only the items
- Sorting tasks: Return only the ordered list

Example Responses:
Q: Count people in image? A: 3
Q: List colors in logo? A: blue, red, white
Q: Main topic of audio? A: weather forecast
Q: Excel total sales? A: 15420
Q: Video duration? A: 45
"""
| |
|
| |
|
class SmolagentToolWrapper(BaseTool):
    """Adapter that exposes a smolagents tool through the LangChain
    ``BaseTool`` interface so LangChain/LangGraph agents can invoke it.
    """

    # The underlying smolagents tool instance; called directly in _run.
    wrapped_tool: object = Field(description="Smolagents tool (wrapped)")

    def __init__(self, tool):
        # Mirror the wrapped tool's identity so the agent sees its original
        # name and description.
        super().__init__(
            name=tool.name,
            description=tool.description,
            return_direct=False,
            wrapped_tool=tool,
        )

    def _run(self, query: str) -> str:
        """Invoke the wrapped tool synchronously; report failures as text."""
        try:
            return self.wrapped_tool(query)
        except Exception as e:
            return f"Error using SmolagentToolWrapper: {str(e)}"

    def _arun(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        """Async version of the tool"""
        # asyncify runs the sync _run in a worker thread and returns an
        # awaitable, which the async tool machinery awaits.
        return asyncify(self._run, cancellable=True)(*args, **kwargs)
| |
|
| |
|
# Shared TavilyExtract instance, created once at import time with the
# configured API key so every agent can reuse it.
tavily_extract_tool = TavilyExtract(tavily_api_key=TAVILY_API_KEY)
| |
|
| |
|
| | @tool("search-tavily-tool", parse_docstring=True) |
| | async def search_tavily( |
| | query: str, |
| | state: Annotated[dict, InjectedState], |
| | included_domains: list[str] = None, |
| | max_results: int = 5, |
| | ) -> dict[str, str]: |
| | """ |
| | Search the web using Tavily API with optional domain filtering. |
| | |
| | This function performs a search using the Tavily search engine and returns formatted results. |
| | You can specify domains to include in the search results for more targeted information. |
| | |
| | Args: |
| | query (str): The search query to search the web for |
| | included_domains (list[str], optional): List of domains to include in search results |
| | (e.g., ["wikipedia.org", "cnn.com"]). Defaults to None. |
| | max_results (int, optional): Maximum number of results to return. Defaults to 5. |
| | |
| | Returns: |
| | dict[str, str]: A dictionary with key 'tavily_results' containing formatted search results. |
| | Each result includes document source, page information, and content. |
| | |
| | Example: |
| | results = await search_tavily("How many albums did Michael Jackson produce", included_domains=[], topic="general") |
| | # Returns filtered results about Michael Jackson |
| | """ |
| | |
| | tavily_search_tool = TavilySearch( |
| | tavily_api_key=TAVILY_API_KEY, |
| | max_results=max_results, |
| | topic="general", |
| | include_domains=included_domains if included_domains else None, |
| | search_depth="advanced", |
| | include_answer="advanced", |
| | ) |
| |
|
| | |
| | search_docs = await tavily_search_tool.arun(state["question"]) |
| |
|
| | |
| | formatted_search_docs = "\n\n---\n\n".join( |
| | [ |
| | f'<Document source="{doc.get("url", "No URL")}"/>{doc.get("title", "No Title")}\n{doc.get("content", "")}\n</Document>' |
| | for doc in search_docs.get("results", []) |
| | ] |
| | ) |
| |
|
| | results = {"tavily_results": formatted_search_docs} |
| |
|
| | answer = search_docs.get("answer", None) |
| |
|
| | if answer: |
| | results["tavily_answer"] = answer |
| |
|
| | return results |
| |
|
| |
|
| | @tool("search-arxiv-tool", parse_docstring=True) |
| | async def search_arxiv(query: str, max_num_result: int = 5) -> dict[str, str]: |
| | """ |
| | Search arXiv for academic papers matching the provided query. |
| | This function queries the arXiv database for scholarly articles related to the |
| | search query and returns a formatted collection of the results. |
| | |
| | Args: |
| | query (str): The search query to find relevant academic papers. |
| | max_num_result (int, optional): Maximum number of results to return. Defaults to 5. |
| | |
| | Returns: |
| | dict[str, str]: A dictionary with key 'arxiv_results' containing formatted search results. |
| | Each result includes document source, page information, and content. |
| | |
| | Example: |
| | results = await search_arxiv("quantum computing", 3) |
| | # Returns dictionary with up to 3 formatted arXiv papers about quantum computing |
| | """ |
| | search_docs = await ArxivLoader(query=query, load_max_docs=max_num_result).aload() |
| | formatted_search_docs = "\n\n---\n\n".join( |
| | [ |
| | f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>' |
| | for doc in search_docs |
| | ] |
| | ) |
| | return {"arvix_results": formatted_search_docs} |
| |
|
| |
|
| | @tool("search-wikipedia-tool", parse_docstring=True) |
| | async def search_wikipedia(query: str, max_num_result: int = 5) -> dict[str, str]: |
| | """ |
| | Search Wikipedia for articles matching the provided query. |
| | This function queries the Wikipedia database for articles related to the |
| | search term and returns a formatted collection of the results. |
| | |
| | Args: |
| | query (str): The search query to find relevant Wikipedia articles. |
| | max_num_result (int, optional): Maximum number of results to return. Defaults to 5. |
| | |
| | Returns: |
| | dict[str, str]: A dictionary with key 'wikipedia_results' containing formatted search results. |
| | Each result includes document source, page information, and content. |
| | |
| | Example: |
| | results = await search_wikipedia("neural networks", 3) |
| | # Returns dictionary with up to 3 formatted Wikipedia articles about neural networks |
| | """ |
| | search_docs = await WikipediaLoader( |
| | query=query, |
| | load_max_docs=max_num_result, |
| | load_all_available_meta=True, |
| | doc_content_chars_max=128000, |
| | ).aload() |
| |
|
| | |
| |
|
| | formatted_search_docs = "\n\n---\n\n".join( |
| | [ |
| | f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>' |
| | for doc in search_docs |
| | ] |
| | ) |
| | return {"wikipedia_results": formatted_search_docs} |
| |
|
| |
|
| | @tool("download-file-for-task-tool", parse_docstring=True) |
| | async def download_file_for_task(task_id: str, filename: str | None = None) -> str: |
| | """ |
| | Download a file for task_id, save to a temporary file, and return path |
| | |
| | Args: |
| | task_id: The task id file to download |
| | filename: Optional filename (will be generated if not provided) |
| | |
| | Returns: |
| | String path to the downloaded file |
| | """ |
| | if filename is None: |
| | filename = task_id |
| |
|
| | temp_dir = Path(tempfile.gettempdir()) |
| | filepath = temp_dir / filename |
| |
|
| | url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}" |
| | async with httpx.AsyncClient() as client: |
| | async with client.stream("GET", url) as response: |
| | response.raise_for_status() |
| | async with await filepath.open("wb") as f: |
| | async for chunk in response.aiter_bytes(chunk_size=4096): |
| | await f.write(chunk) |
| |
|
| | return str(filepath) |
| |
|
| |
|
| | @tool("read-file-contents-tool", parse_docstring=True) |
| | async def read_file_contents(file_path: str) -> str: |
| | """ |
| | Read a file and return its contents |
| | |
| | Args: |
| | file_path: String path to file to read |
| | |
| | Returns: |
| | Contents of the file at file_path |
| | """ |
| | path = Path(file_path) |
| | return await path.read_text() |
| |
|
| |
|
| | @tool("analyze-image-tool", parse_docstring=True) |
| | async def analyze_image(state: Annotated[dict, InjectedState], image_path: str) -> str: |
| | """ |
| | Analyze the image at image_path |
| | |
| | Args: |
| | image_path: String path where the image file is located on disk |
| | |
| | Returns: |
| | Answer to the question about the image file |
| | """ |
| | path = Path(image_path) |
| | async with await path.open("rb") as rb: |
| | img_base64 = b64encode(await rb.read()).decode("utf-8") |
| |
|
| | llm = get_llm( |
| | llm_provider_api_key=GOOGLE_API_KEY, |
| | model_name=AGENT_MODEL_NAME, |
| | ) |
| |
|
| | file_agent = create_react_agent( |
| | model=llm, |
| | tools=[], |
| | prompt=MULTIMODAL_FILE_ANALYZER_PROMPT |
| | ) |
| |
|
| | message = HumanMessage( |
| | content=[ |
| | {"type": "text", "text": state["question"]}, |
| | { |
| | "type": "image", |
| | "source_type": "base64", |
| | "mime_type": "image/png", |
| | "data": img_base64, |
| | }, |
| | ] |
| | ) |
| |
|
| | messages = await file_agent.ainvoke({"messages": [message]}) |
| | return messages["messages"][-1].content |
| |
|
| |
|
| | @tool("analyze-excel-tool", parse_docstring=True) |
| | async def analyze_excel(state: Annotated[dict, InjectedState], excel_path: str) -> str: |
| | """ |
| | Analyze the excel file at excel_path |
| | |
| | Args: |
| | excel_path: String path where the excel file is located on disk |
| | |
| | Returns: |
| | Answer to the question about the excel file |
| | """ |
| |
|
| | df = pd.read_excel(excel_path) |
| |
|
| | csv_buffer = StringIO() |
| | df.to_csv(csv_buffer, index=False) |
| |
|
| | csv_contents = csv_buffer.getvalue() |
| | csv_contents_bytes = csv_contents.encode("utf-8") |
| | csv_contents_base64 = b64encode(csv_contents_bytes).decode("utf-8") |
| |
|
| | llm = get_llm( |
| | llm_provider_api_key=GOOGLE_API_KEY, |
| | model_name=AGENT_MODEL_NAME, |
| | ) |
| |
|
| | file_agent = create_react_agent( |
| | model=llm, |
| | tools=[], |
| | prompt=MULTIMODAL_FILE_ANALYZER_PROMPT |
| | ) |
| |
|
| | message = HumanMessage( |
| | content=[ |
| | {"type": "text", "text": state["question"]}, |
| | { |
| | "type": "file", |
| | "source_type": "base64", |
| | "mime_type": "text/csv", |
| | "data": csv_contents_base64, |
| | }, |
| | ], |
| | ) |
| |
|
| | messages = await file_agent.ainvoke({"messages": [message]}) |
| | return messages["messages"][-1].content |
| |
|
| |
|
| | @tool("analyze-audio-tool", parse_docstring=True) |
| | async def analyze_audio(state: Annotated[dict, InjectedState], audio_path: str) -> str: |
| | """ |
| | Analyze the audio at audio_path |
| | |
| | Args: |
| | audio_path: String path where the audio file is located on disk |
| | |
| | Returns: |
| | Answer to the question about the audio file |
| | """ |
| | audio_mime_type = "audio/mpeg" |
| |
|
| | path = Path(audio_path) |
| |
|
| | async with await path.open("rb") as rb: |
| | encoded_audio = b64encode(await rb.read()).decode("utf-8") |
| |
|
| | llm = get_llm( |
| | llm_provider_api_key=GOOGLE_API_KEY, |
| | model_name=AGENT_MODEL_NAME, |
| | ) |
| |
|
| | file_agent = create_react_agent( |
| | model=llm, |
| | tools=[], |
| | prompt=MULTIMODAL_FILE_ANALYZER_PROMPT |
| | ) |
| |
|
| | message = HumanMessage( |
| | content=[ |
| | {"type": "text", "text": state["question"]}, |
| | {"type": "media", "data": encoded_audio, "mime_type": audio_mime_type}, |
| | ], |
| | ) |
| |
|
| | messages = await file_agent.ainvoke({"messages": [message]}) |
| | return messages["messages"][-1].content |
| |
|
| |
|
| | @tool("analyze-video-tool", parse_docstring=True) |
| | async def analyze_video(state: Annotated[dict, InjectedState], video_url: str) -> str: |
| | """ |
| | Analyze the video at video_url |
| | |
| | Args: |
| | video_url: URL where the video is located |
| | |
| | Returns: |
| | Answer to the question about the video url |
| | """ |
| | llm = get_llm( |
| | llm_provider_api_key=GOOGLE_API_KEY, |
| | model_name=AGENT_MODEL_NAME, |
| | ) |
| |
|
| | file_agent = create_react_agent( |
| | model=llm, |
| | tools=[], |
| | prompt=MULTIMODAL_FILE_ANALYZER_PROMPT |
| | ) |
| |
|
| | message = HumanMessage( |
| | content=[ |
| | {"type": "text", "text": state["question"]}, |
| | { |
| | "type": "media", |
| | "mime_type": "video/mp4", |
| | "file_uri": video_url, |
| | }, |
| | ], |
| | ) |
| |
|
| | messages = await file_agent.ainvoke({"messages": [message]}) |
| | return messages["messages"][-1].content |
| |
|