|
|
""" |
|
|
LangChain-compatible tools for the LangGraph multi-agent system.

This module provides LangChain tools that work properly with LangGraph agents,
replacing the LlamaIndex tools with native LangChain implementations.
|
|
""" |
|
|
|
|
|
import os |
|
|
import wikipedia |
|
|
import arxiv |
|
|
from typing import List, Optional, Type |
|
|
from langchain_core.tools import BaseTool, tool |
|
|
from pydantic import BaseModel, Field |
|
|
from huggingface_hub import list_models |
|
|
from observability import tool_span |
|
|
|
|
|
|
|
|
# Optional dependency: langchain_tavily may be absent. When the import fails
# we record that in TAVILY_AVAILABLE and bind TavilySearch to None so later
# code can test for it without raising NameError.
try:
    from langchain_tavily import TavilySearch
except ImportError as exc:
    print(f"Warning: langchain_tavily not available: {exc}")
    TAVILY_AVAILABLE = False
    TavilySearch = None
else:
    TAVILY_AVAILABLE = True
|
|
|
|
|
|
|
|
|
|
|
class WikipediaSearchInput(BaseModel):
    """Input for Wikipedia search tool."""

    # Single required argument; the explicit ``...`` marks it as required,
    # which is what omitting a default already meant.
    query: str = Field(..., description="The search query for Wikipedia")
|
|
|
|
|
|
|
|
class ArxivSearchInput(BaseModel):
    """Input for ArXiv search tool."""

    # Single required argument; the explicit ``...`` marks it as required,
    # which is what omitting a default already meant.
    query: str = Field(..., description="The search query for ArXiv papers")
|
|
|
|
|
|
|
|
class HubStatsInput(BaseModel):
    """Input for Hugging Face Hub stats tool."""

    # Single required argument; the explicit ``...`` marks it as required,
    # which is what omitting a default already meant.
    author: str = Field(
        ...,
        description="The author/organization name on Hugging Face Hub",
    )
|
|
|
|
|
|
|
|
class TavilySearchInput(BaseModel):
    """Input for Tavily search tool."""

    # Single required argument; the explicit ``...`` marks it as required,
    # which is what omitting a default already meant.
    query: str = Field(..., description="The search query for web search")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _format_wikipedia_page(page, limit: int = 1000) -> str:
    """Render a fetched Wikipedia page as the tool's text result.

    Args:
        page: Object exposing ``title``, ``url`` and ``summary`` attributes
            (as returned by ``wikipedia.page``).
        limit: Maximum number of summary characters to keep.

    Returns:
        Title, URL and summary; the summary is suffixed with "..." only when
        it was actually cut at ``limit`` characters.
    """
    content = page.summary
    if len(content) > limit:
        content = content[:limit] + "..."
    return f"Wikipedia: {page.title}\n\nURL: {page.url}\n\nSummary:\n{content}"


@tool("wikipedia_search", args_schema=WikipediaSearchInput)
def wikipedia_search_tool(query: str) -> str:
    """Search Wikipedia for information about a topic."""
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it as the tool description - confirm
    # before rewording it.
    #
    # Contract: always returns a string. Every failure is reported as text
    # instead of being raised, so a calling agent never sees an exception.
    try:
        with tool_span("wikipedia_search", metadata={"query": query}):
            try:
                search_results = wikipedia.search(query, results=3)
                if not search_results:
                    return f"No Wikipedia results found for '{query}'"

                # Return the first candidate that can be fetched; skip the
                # ones that fail and try the next title.
                for page_title in search_results:
                    try:
                        return _format_wikipedia_page(wikipedia.page(page_title))
                    except wikipedia.exceptions.DisambiguationError as e:
                        # Ambiguous title: retry once with the first
                        # suggested option before moving on.
                        try:
                            return _format_wikipedia_page(
                                wikipedia.page(e.options[0])
                            )
                        except Exception:
                            # Not a bare ``except`` so KeyboardInterrupt and
                            # SystemExit still propagate.
                            continue
                    except Exception:
                        continue

                return f"Could not retrieve Wikipedia content for '{query}'"
            except Exception as e:
                # Failure of the search call itself, inside the span.
                return f"Wikipedia search error: {e}"
    except Exception as e:
        # Failure entering or leaving the observability span.
        return f"Wikipedia search failed: {e}"
|
|
|
|
|
|
|
|
@tool("arxiv_search", args_schema=ArxivSearchInput)
def arxiv_search_tool(query: str) -> str:
    """Search ArXiv for academic papers."""
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it as the tool description - confirm
    # before rewording it.
    #
    # Contract: always returns a string; up to three papers sorted by
    # relevance, or a "no results" / "failed" message.
    try:
        with tool_span("arxiv_search", metadata={"query": query}):
            search = arxiv.Search(
                query=query,
                max_results=3,
                sort_by=arxiv.SortCriterion.Relevance,
            )

            # Search.results() is deprecated in newer arxiv releases in
            # favour of Client.results(); fall back to the old call when the
            # installed release has no Client.
            if hasattr(arxiv, "Client"):
                papers = arxiv.Client().results(search)
            else:
                papers = search.results()

            results = []
            for paper in papers:
                # Add the ellipsis only when the abstract was really cut.
                summary = paper.summary
                if len(summary) > 500:
                    summary = summary[:500] + "..."
                authors = ", ".join(author.name for author in paper.authors)
                results.append(
                    f"Title: {paper.title}\n"
                    f"Authors: {authors}\n"
                    f"Published: {paper.published.strftime('%Y-%m-%d')}\n"
                    f"URL: {paper.entry_id}\n"
                    f"Summary: {summary}"
                )

            if not results:
                return f"No ArXiv papers found for '{query}'"
            return (
                f"ArXiv Search Results for '{query}':\n\n"
                + "\n\n---\n\n".join(results)
            )
    except Exception as e:
        return f"ArXiv search failed: {e}"
|
|
|
|
|
|
|
|
@tool("huggingface_hub_stats", args_schema=HubStatsInput)
def huggingface_hub_stats_tool(author: str) -> str:
    """Get statistics for a Hugging Face Hub author."""
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it as the tool description - confirm
    # before rewording it.
    #
    # Contract: always returns a string; a ranked list of up to five models
    # ordered by downloads, or a "no models" / error message.
    try:
        with tool_span("huggingface_hub_stats", metadata={"author": author}):
            models = list(
                list_models(author=author, sort="downloads", direction=-1, limit=5)
            )
            if not models:
                return f"No models found for author '{author}'"

            def download_count(model) -> int:
                # ``downloads`` can be missing/None; treat that as 0 so the
                # ``:,`` format below does not raise TypeError.
                return getattr(model, "downloads", None) or 0

            results = [
                f"{rank}. {model.id} - {download_count(model):,} downloads"
                for rank, model in enumerate(models, 1)
            ]

            # The listing is requested sorted by downloads, so the first
            # entry is the most downloaded one.
            top_model = models[0]
            summary = f"Top 5 models by {author}:\n" + "\n".join(results)
            summary += (
                f"\n\nMost popular: {top_model.id} "
                f"with {download_count(top_model):,} downloads"
            )
            return summary
    except Exception as e:
        return f"Hub stats error: {e}"
|
|
|
|
|
|
|
|
@tool("tavily_search_results_json", args_schema=TavilySearchInput)
def tavily_search_fallback_tool(query: str) -> str:
    """Fallback web search tool when Tavily is not available."""
    # NOTE(review): the docstring above is kept verbatim because the @tool
    # decorator presumably exposes it as the tool description - confirm
    # before rewording it.
    #
    # This fallback only issues the request and reports whether it succeeded;
    # the response body is never parsed, so no actual search hits are
    # returned to the caller.
    try:
        with tool_span("tavily_search_fallback", metadata={"query": query}):
            # Imported lazily: only needed when the fallback is really used.
            import requests

            search_url = "https://duckduckgo.com/lite/"
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
            }

            try:
                # Pass the query via ``params`` so requests URL-encodes it;
                # interpolating it into the URL breaks on '&', '#', '+',
                # spaces and non-ASCII text.
                response = requests.get(
                    search_url,
                    params={"q": query},
                    headers=headers,
                    timeout=10,
                )
                if response.status_code == 200:
                    return f"Web search completed for '{query}'. Found general web results (fallback mode - Tavily not available)."
                return f"Web search failed for '{query}' (status: {response.status_code})"
            except Exception as e:
                # Network-level failure (timeout, DNS, connection reset...).
                return f"Web search error for '{query}': {e}"
    except Exception as e:
        # Import failure or a problem with the observability span.
        return f"Web search failed: {e}"
|
|
|
|
|
|
|
|
def get_tavily_search_tool() -> BaseTool:
    """Return the web-search tool, preferring Tavily over the fallback.

    Returns:
        A configured ``TavilySearch`` instance when the package was imported
        and construction succeeds; otherwise ``tavily_search_fallback_tool``.
        A warning is printed whenever the fallback is chosen.
    """
    # Guard clause: package missing, so there is nothing to construct.
    if not (TAVILY_AVAILABLE and TavilySearch):
        print("Warning: Using fallback search tool (Tavily not available)")
        return tavily_search_fallback_tool

    try:
        search_tool = TavilySearch(
            api_key=os.getenv("TAVILY_API_KEY"),
            max_results=6,
            include_answer=True,
            include_raw_content=True,
            description="Search the web for current information and facts",
        )
    except Exception as exc:
        # Construction problems degrade to the fallback instead of
        # propagating to the caller.
        print(f"Warning: Failed to create TavilySearch tool: {exc}")
        return tavily_search_fallback_tool
    return search_tool
|
|
|
|
|
|
|
|
def get_calculator_tools() -> List[BaseTool]:
    """Build the arithmetic tools.

    Returns:
        Newly created tools in the order multiply, add, subtract, divide,
        modulus. ``divide`` and ``modulus`` raise ``ValueError`` when the
        second operand is zero.
    """
    # The inner docstrings are left exactly as they were: the @tool
    # decorator presumably uses them as the tool descriptions.

    @tool("multiply")
    def multiply(a: float, b: float) -> float:
        """Multiply two numbers."""
        product = a * b
        return product

    @tool("add")
    def add(a: float, b: float) -> float:
        """Add two numbers."""
        total = a + b
        return total

    @tool("subtract")
    def subtract(a: float, b: float) -> float:
        """Subtract two numbers."""
        difference = a - b
        return difference

    @tool("divide")
    def divide(a: float, b: float) -> float:
        """Divide two numbers."""
        if b != 0:
            return a / b
        raise ValueError("Cannot divide by zero")

    @tool("modulus")
    def modulus(a: int, b: int) -> int:
        """Get the modulus of two integers."""
        if b != 0:
            return a % b
        raise ValueError("Cannot modulo by zero")

    calculator_tools = [multiply, add, subtract, divide, modulus]
    return calculator_tools
|
|
|
|
|
|
|
|
def get_research_tools() -> List[BaseTool]:
    """Assemble the tool set for the research agent.

    Returns:
        A new list holding, in order, the web-search tool chosen by
        ``get_tavily_search_tool()``, the Wikipedia tool and the ArXiv tool.
    """
    web_search = get_tavily_search_tool()
    return [web_search, wikipedia_search_tool, arxiv_search_tool]
|
|
|
|
|
|
|
|
def get_code_tools() -> List[BaseTool]:
    """Assemble the tool set for the code agent.

    Returns:
        A new list with the calculator tools followed by the Hugging Face
        Hub stats tool.
    """
    return [*get_calculator_tools(), huggingface_hub_stats_tool]
|
|
|
|
|
|
|
|
def get_all_tools() -> List[BaseTool]:
    """Return every tool: research tools first, then code tools."""
    every_tool = [*get_research_tools(), *get_code_tools()]
    return every_tool