# AlaaWO's picture — "Update agent.py" (commit c23121a, verified)
# NOTE(review): Hugging Face web-page header accidentally captured with the
# file; commented out so the module parses as valid Python.
"""langgraph ReAct LLAMA instruct agent"""
from dotenv import load_dotenv
import os
from typing import TypedDict, List, Dict, Any, Optional
from langchain_tavily import TavilySearch
from langchain_core.tools import tool
import requests
from urllib.parse import urlparse
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition,ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain.schema import HumanMessage, SystemMessage
import json
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain.agents import initialize_agent
from langchain.agents.agent_types import AgentType
import pandas as pd
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.document_loaders import ArxivLoader
import sympy
from sympy import sympify
load_dotenv()
@tool
def arvix_search(query: str) -> str:
    """
    Search Arxiv for a query and return up to 3 results.

    Args:
        query: The search query.

    Returns:
        A string with formatted Arxiv search results (truncated to 1000 chars each).
    """
    # NOTE: the tool name "arvix" is a typo for "arxiv"; kept as-is because the
    # registered tool name is part of the agent's interface (see tools list).
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # BUG FIX: ArxivLoader documents expose metadata keys such as "Published",
    # "Title" and "Authors" — a "source" key is not guaranteed, so the previous
    # doc.metadata["source"] lookup could raise KeyError and kill the tool call.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata.get("source", doc.metadata.get("Title", ""))}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    return formatted_search_docs
@tool
def wiki_search(query: str) -> str:
    """
    Search Wikipedia for a query and return up to 2 formatted results.

    Args:
        query: The search query.

    Returns:
        A string with formatted Wikipedia search results.
    """
    docs = WikipediaLoader(query=query, load_max_docs=2).load()
    # Wrap each hit in a pseudo-XML <Document> envelope so the LLM can see
    # where one result ends and the next begins.
    rendered = []
    for doc in docs:
        source = doc.metadata["source"]
        page = doc.metadata.get("page", "")
        rendered.append(
            f'<Document source="{source}" page="{page}"/>\n{doc.page_content}\n</Document>'
        )
    return "\n\n---\n\n".join(rendered)
@tool
def analyze_excel_file(input_str: str) -> str:
    """
    Analyze an Excel file using pandas and answer a question about it.

    Args:
        input_str: JSON string with fields:
            - file_path: Path to the Excel file
            - query: A question about the file contents (optional)

    Returns:
        A summary of the file contents or an error message.
    """
    try:
        import json
        import pandas as pd

        params = json.loads(input_str)
        path = params.get("file_path")
        question = params.get("query")
        if not path:
            return "Error: 'file_path' is required."
        # Open the workbook once; only the first sheet is summarized by default.
        workbook = pd.ExcelFile(path)
        sheets = workbook.sheet_names
        first = sheets[0]
        frame = pd.read_excel(workbook, sheet_name=first)
        parts = [
            f"Excel file loaded with sheets: {', '.join(sheets)}.\n\n",
            f"First sheet '{first}' loaded with {len(frame)} rows and {len(frame.columns)} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe(include='all')),
        ]
        if question:
            parts.append(f"\n\nQuery: {question} (No advanced query handling implemented yet.)")
        return "".join(parts)
    except json.JSONDecodeError:
        return "Error: Input must be a valid JSON string with 'file_path' and optional 'query'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
@tool
def web_search(query: str) -> str:
    """
    Perform a web search using Tavily and return the result.
    """
    try:
        response = TavilySearch().invoke(query)
        # Tavily normally returns a dict with a "results" list; anything else
        # is surfaced verbatim so the agent can see what went wrong.
        if not (isinstance(response, dict) and "results" in response):
            return f"Error: Unexpected Tavily response format: {response}"
        formatted = [
            f"{doc['title']}\n{doc['url']}\n{doc['content']}"
            for doc in response["results"]
        ]
        return "\n\n---\n\n".join(formatted)
    except Exception as e:
        return f"Error using TavilySearch: {str(e)}"
@tool
def analyze_csv_file(input_str: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        input_str: JSON string with fields:
            - file_path: Path to the CSV file (required)
            - query: A question about the file contents (optional)

    Returns:
        A basic analysis of the file or an error message.
    """
    try:
        # Parse the JSON string
        data = json.loads(input_str)
        file_path = data.get("file_path")
        query = data.get("query")
        if not file_path:
            return "Error: 'file_path' is required."
        # Read the CSV
        df = pd.read_csv(file_path)
        # Basic metadata
        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        result += f"Columns: {', '.join(df.columns)}\n\n"
        result += "Summary statistics:\n"
        # BUG FIX: describe(datetime_is_numeric=True) was deprecated in pandas
        # 1.5 and removed in 2.0; passing it raises TypeError, which the broad
        # except below swallowed — so the tool always returned an error string
        # on modern pandas. Datetime columns are treated as numeric by default.
        result += str(df.describe(include='all'))
        # Optionally echo the query (no query-answering logic implemented yet)
        if query:
            result += f"\n\nQuery: {query} (No logic implemented yet to answer it.)"
        return result
    except json.JSONDecodeError:
        return "Error: Input must be a valid JSON string with 'file_path' and optional 'query'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
@tool
def download_file_from_url(input_str: str) -> str:
    """
    Downloads a file from a URL and saves it in the 'saved_files' directory.

    Args:
        input_str (str): A JSON string with keys:
            - "url": the URL to download from (required)
            - "filename": optional filename to save as

    Returns:
        A message indicating success and file path, or an error message.
    """
    try:
        # Parse the input string
        data = json.loads(input_str)
        url = data.get("url")
        filename = data.get("filename", None)
        if not url:
            return "Error: 'url' is required in the input JSON."
        # Create directory if not exists
        new_dir = os.path.join(os.getcwd(), "saved_files")
        os.makedirs(new_dir, exist_ok=True)
        # Generate filename if not provided
        if not filename:
            path = urlparse(url).path
            filename = os.path.basename(path) or f"downloaded_{os.urandom(4).hex()}"
        # SECURITY FIX: strip any directory components so a crafted "filename"
        # (e.g. "../../etc/passwd") cannot escape the saved_files directory.
        filename = os.path.basename(filename)
        filepath = os.path.join(new_dir, filename)
        # BUG FIX: requests.get() without a timeout can hang the agent forever;
        # the with-block also guarantees the connection is released.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()
            # Stream to disk in chunks to avoid loading large files in memory
            with open(filepath, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        return f"File downloaded to {filepath}. You can now process this file."
    except json.JSONDecodeError:
        return "Error: Invalid JSON input. Expected format: {\"url\": \"...\", \"filename\": \"optional_name\"}"
    except Exception as e:
        return f"Error: {str(e)}"
@tool
def find_file_for_question(input_str: str) -> str:
    """
    Constructs a multimodal question prompt for the agent to answer.

    Args:
        input_str (str): JSON string with keys:
            - task_id: ID of the file
            - question: The actual question
            - file_name: (optional) file name, if image is involved

    Returns:
        A full natural language prompt that includes the file URL if needed.
    """
    try:
        payload = json.loads(input_str)
        task_id = payload.get("task_id")
        question = payload.get("question")
        file_name = payload.get("file_name")
        if not (task_id and question):
            return "Error: Missing 'task_id' or 'question' in input."
        # Without an attached file the question is already a complete prompt.
        if not file_name:
            return question
        # Point the model at the scoring service's file endpoint for this task.
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        return question + f"\n\nImage file to consider: {file_url}"
    except json.JSONDecodeError:
        return "Error: Invalid input. Provide JSON with 'task_id', 'question', and optional 'file_name'."
    except Exception as e:
        return f"Error: {str(e)}"
@tool
def calculate_math_expression(expr: str) -> str:
    """
    Evaluate a symbolic math expression (e.g., algebraic, numeric, or arithmetic).
    Use this tool if the input is a math expression like '2 + 3*sqrt(4)', 'sin(pi/2)', or '3 ** 2'.

    Input:
        A raw string expression. Example: '2 + 3 * sqrt(4)'

    Returns:
        A float result as a string if successful,
        otherwise a string with the error message.
    """
    try:
        # SECURITY NOTE(review): sympify() uses eval() internally on its string
        # input; confirm expressions always come from the model, never directly
        # from end users.
        parsed = sympify(expr)
        # Plain Python values returned by sympify have no evalf(); return them
        # as-is instead of trying to evaluate further.
        return str(parsed.evalf()) if hasattr(parsed, "evalf") else str(parsed)
    except Exception as e:
        return f"Error: {str(e)}"
class AgentState(TypedDict):
    """State carried through the agent graph."""
    # BUG FIX: was annotated `str`, but every producer/consumer in this file
    # treats it as a list of chat messages (e.g. `[HumanMessage(...)]`).
    messages: List[Any]            # Conversation messages (Human/System/AI/Tool)
    attachments: Dict[str, Any]    # Attachments (e.g., images, files) related to the question
    context: List[Dict]            # Retrieved context (e.g., search results, documents)
    reasoning: List[str]           # Step-by-step reasoning traces
    partial_answer: Optional[str]  # Intermediate answer (if multi-step)
    final_answer: Optional[str]    # Final answer to return
    tools_used: List[str]          # Track which tools were called (for debugging)
    # NOTE(review): without an `add_messages` reducer annotation, LangGraph
    # replaces (not appends to) `messages` on each node update — confirm that
    # is intended, or build the graph over the prebuilt MessagesState instead.
# Registry of tools exposed to the agent; ToolNode dispatches on these names.
# NOTE: "arvix_search" is a typo for "arxiv_search", kept for compatibility.
tools = [
    find_file_for_question,
    analyze_excel_file,
    analyze_csv_file,
    web_search,
    arvix_search,
    wiki_search,
    download_file_from_url,
    calculate_math_expression]
# Build graph function
def build_graph():
    """Build and compile the ReAct agent graph.

    Returns:
        A compiled LangGraph that alternates between an LLM "assistant" node
        and a "tools" node until the model stops requesting tool calls.
    """
    llm = HuggingFaceEndpoint(
        repo_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
        temperature=0,
        provider="novita",
    )
    chat_model = ChatHuggingFace(llm=llm)
    # BUG FIX: the previous code wrapped a legacy initialize_agent() executor
    # inside the graph node. That executor returns a plain dict
    # ({"input": ..., "output": ...}), which tools_condition/ToolNode cannot
    # route on — they inspect `tool_calls` on an AIMessage. Binding the tool
    # schemas directly to the chat model yields proper AIMessage outputs.
    llm_with_tools = chat_model.bind_tools(tools)

    system_prompt = """
    You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER].
    YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
    If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
    If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
    If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
    """

    def assistant(state: MessagesState):
        """LLM node: prepend the system prompt and call the tool-bound model."""
        sys_msg = SystemMessage(content=system_prompt)
        return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}

    # BUG FIX: use the prebuilt MessagesState (already imported above) — it
    # carries the add_messages reducer, so each node's returned messages are
    # appended to the history instead of overwriting it, which the
    # assistant <-> tools loop requires.
    builder = StateGraph(MessagesState)
    # Define nodes: these do the work
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))
    # Define edges: route to tools when the last AIMessage requests them,
    # otherwise end with the direct response.
    builder.add_edge(START, "assistant")
    builder.add_conditional_edges(
        "assistant",
        tools_condition,
    )
    builder.add_edge("tools", "assistant")
    return builder.compile()
if __name__ == "__main__":
    # Smoke-test the agent with a sample question.
    question = "what was the first university in the world?"
    messages = [HumanMessage(content=question)]
    output = build_graph().invoke({"messages": messages})
    # BUG FIX: the graph state holds LangChain message objects (Human/System/
    # AI/Tool), not {"input": ..., "output": ...} dicts — indexing a message
    # with entry["input"] raises TypeError. Print each message generically.
    for msg in output["messages"]:
        label = type(msg).__name__
        content = getattr(msg, "content", str(msg))
        print(f"{label}: {content}")
        print("-" * 50)