sqfoo's picture
Update agent.py
b0ded06 verified
raw
history blame
7.18 kB
import os
from typing import TypedDict, List, Dict, Any, Optional
from langchain.agents import create_tool_calling_agent, AgentExecutor, initialize_agent
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.tools import tool
from langchain_core.messages import HumanMessage
from langchain_core.prompts import ChatPromptTemplate
# 1. Web Browsing
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.document_loaders import ImageCaptionLoader
import requests, time
import pandas as pd
from pypdf import PdfReader
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from youtube_transcript_api import YouTubeTranscriptApi
@tool
def web_search(query: str) -> str:
    """Allows search through DuckDuckGo.
    Returns the search result snippets as one text string, or an error
    message if the search fails.
    Args:
        query: what you want to search
    """
    try:
        # DuckDuckGoSearchRun.invoke already returns a single string; joining it
        # with "\n" (as before) inserted a newline between every character.
        return str(DuckDuckGoSearchRun().invoke(query))
    except Exception as e:  # tool boundary: report the failure to the agent instead of aborting the run
        return f"[ERROR searching for {query!r}]: {e}"
@tool
def visit_webpage(url: str) -> str:
    """Fetches raw HTML content of a web page.
    Returns an error message instead of the page body when the request fails
    or the server answers with an HTTP error status (4xx/5xx).
    Args:
        url: the webpage url
    """
    try:
        response = requests.get(url, timeout=5)
        # Without this check, 404/500 error pages were handed to the LLM as if
        # they were the requested content.
        response.raise_for_status()
        return response.text
    except Exception as e:  # tool boundary: report the failure to the agent instead of aborting the run
        return f"[ERROR fetching {url}]: {str(e)}"
# Maximum number of characters of Wikipedia page text handed back to the agent.
# The previous limit of 100 characters cut pages off after roughly one sentence,
# too little for questions such as counting titles in a career record.
WIKI_CONTENT_CHARS_MAX = 4000


@tool
def wiki_search(query: str) -> str:
    """Wiki search tools.
    Returns the text of the best-matching Wikipedia page (truncated to a
    fixed maximum length).
    Args:
        query: what you want to wiki
    """
    api_wrapper = WikipediaAPIWrapper(
        top_k_results=1,
        doc_content_chars_max=WIKI_CONTENT_CHARS_MAX,
    )
    wikipediatool = WikipediaQueryRun(api_wrapper=api_wrapper)
    return wikipediatool.run({"query": query})
def _extract_youtube_video_id(video_url: str) -> str:
    """Return the video id contained in a YouTube URL.

    Handles ``youtube.com/watch?v=ID``, ``youtu.be/ID`` and
    ``youtube.com/{shorts,embed,live,v}/ID`` forms. For anything else (e.g. a
    bare id or a scheme-less URL) it falls back to the text after the last
    ``v=`` up to the next ``&``, which returns a bare id unchanged.
    """
    from urllib.parse import parse_qs, urlparse

    parsed = urlparse(video_url.strip())
    host = (parsed.hostname or "").lower()
    if host.endswith("youtu.be"):
        # Short links carry the id as the first path segment: youtu.be/<id>?si=...
        return parsed.path.lstrip("/").split("/")[0]
    if "youtube" in host:
        query_ids = parse_qs(parsed.query).get("v")
        if query_ids:
            return query_ids[0]
        segments = [segment for segment in parsed.path.split("/") if segment]
        if len(segments) >= 2 and segments[0] in ("shorts", "embed", "live", "v"):
            return segments[1]
    # Fallback: original heuristic, which keeps bare ids working.
    return video_url.split("v=")[-1].split("&")[0]


@tool
def youtube_transcript(video_url: str) -> str:
    """Fetched youtube transcript
    Accepts youtube.com/watch?v=..., youtu.be/..., /shorts/... and /embed/...
    links. Returns the transcript text, or an error message if it cannot be fetched.
    Args:
        video_url: YouTube video url
    """
    try:
        video_id = _extract_youtube_video_id(video_url)
        transcript = YouTubeTranscriptApi.get_transcript(video_id)
        return " ".join(item["text"] for item in transcript)
    except Exception as e:  # tool boundary: report the failure to the agent instead of aborting the run
        return f"Error fetching transcript: {str(e)}"
# 4. File Reading
@tool
def read_file(dir: str) -> str:
    """Read the content of the provided file
    Excel workbooks (.xlsx, .xls) are rendered as a table, PDFs as the text of
    every page, and any other file is read as UTF-8 text. Returns an error
    message if the file cannot be read.
    Args:
        dir: the filepath
    """
    # splitext handles paths without an extension and dots in directory names;
    # the previous `dir.split['.']` subscripted the method and always raised TypeError.
    extension = os.path.splitext(dir)[1].lower()
    try:
        if extension in (".xlsx", ".xls"):
            return pd.read_excel(dir).to_string()
        if extension == ".pdf":
            reader = PdfReader(dir)
            # extract_text() can yield nothing for image-only pages; treat that as empty text.
            return "\n".join(page.extract_text() or "" for page in reader.pages)
        with open(dir, encoding="utf-8") as f:
            return f.read()
    except Exception as e:  # tool boundary: report the failure to the agent instead of aborting the run
        return f"[ERROR reading {dir}]: {e}"
# 5. Image Open
@tool
def image_caption(dir: str) -> str:
    """Understand the content of the provided image
    Returns a generated caption describing the image, or an error message if
    no caption could be produced.
    Args:
        dir: the image file path or url link
    """
    try:
        documents = ImageCaptionLoader(images=[dir]).load()
    except Exception as e:  # tool boundary: report the failure to the agent instead of aborting the run
        return f"[ERROR captioning {dir}]: {e}"
    if not documents:
        # Indexing [0] on an empty result previously raised IndexError.
        return f"[ERROR captioning {dir}]: no caption was produced"
    return documents[0].page_content
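# 2. Coding (TODO: no code-execution tool is implemented yet)
# 3. Multi-Modality (TODO: so far covered only by image_caption and youtube_transcript)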
# ("human", f"Question: {question}\nReport to validate: {final_answer}")
class BasicAgent:
    """ReAct agent backed by Gemini that answers questions with the tools above.

    Calling an instance with a question string runs the agent and returns the
    text the agent produced as its final answer.
    """

    # Custom instructions placed before the ReAct tool list and format rules
    # (ZeroShotAgent "prefix").
    # NOTE: the ReAct output parser looks for the literal marker "Final Answer:",
    # so the template below uses exactly that spelling; "FINAL ANSWER:" would
    # make every final step unparsable.
    # The text is compiled into a prompt template, so it must not contain curly braces.
    SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: Final Answer: [YOUR FINAL ANSWER].
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separared list of numbers and/or strings.
If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (eg. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to put in the list is a number or a string.
You have access to the following tools:
- web_search: web search the content of the query by passing the query as input
- visit_webpage: visit the given webpage url by passing the url as input
- wiki_search: wiki search the content of the query by passing the query as input if the question asks for wiki search it
- youtube_transcript: fetch the transcript of the Youtube video by passing the video url as input if the question asks for watching a Youtube video
- read_file: read the content of the attached file by passing the file directory as input
- image_caption: understand the visual content of the attached image by passing the image directory as input
HERE are some examples illustrating how and what tools to call.
---------------
TASK: Count how many birds in the provided Youtube video.
ACTION: Call youtube_transcript tool to extract video transcript. Use LLM to understand the retrived transcript.
TASK: How many Grammy Awards that Taylor Swift has won.
ACTION: Call the web_search tools with the query: 'how many Grammy Awards that Taylor Swift has won.' to extract the answer.
TASK: Count how many people in this image.
ACTION: Call the image_caption tool by passing the image directory as input. Then, use LLM to understand the image caption and answer the question.
TASK: How much the total expense in this spreadsheet?
ACTION: Call the read_file tool to extract the content of the provided spreadfile. Then, use LLM to extract the amount of every expense and sum them up.
TASK: How many All England Title that Lee Chong Wei won?
ACTION: Call wiki_search with the query: "Lee Chong Wei". Extract the relevant row of All England Title and count how many rows is there.
"""

    # Seconds to wait before answering each question. Presumably throttles
    # calls to stay under the Gemini API rate limit — confirm against the quota in use.
    request_delay_seconds = 15

    def __init__(self):
        # Credentials must never be committed to source: the key that used to be
        # hard-coded here is public and should be revoked. Read it from the environment.
        api_key = os.environ.get("GOOGLE_API_KEY")
        if not api_key:
            raise RuntimeError(
                "GOOGLE_API_KEY environment variable is not set; "
                "it is required to call the Gemini API."
            )
        self.model = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            temperature=0,
            # 128 tokens could cut off a Thought/Action/Final Answer step
            # mid-way, leaving output the ReAct parser cannot read.
            max_tokens=1024,
            timeout=None,
            max_retries=2,
            google_api_key=api_key,
        )
        self.sys_prompt = self.SYSTEM_PROMPT
        self.tools = [
            web_search,
            visit_webpage,
            wiki_search,
            youtube_transcript,
            read_file,
            image_caption,
        ]
        # Kept as a public attribute for callers that inspect it; the ReAct
        # agent below builds its own prompt from `sys_prompt`.
        self.prompt = ChatPromptTemplate.from_messages([
            ("system", self.sys_prompt),
            ("human", "{input}"),
        ])
        self.agent = initialize_agent(
            tools=self.tools,
            llm=self.model,
            agent="zero-shot-react-description",  # ReAct agent type
            verbose=True,
            # `system_prompt` is not an initialize_agent parameter, so the
            # instructions never reached the agent. ZeroShotAgent accepts
            # custom leading instructions as `prefix`.
            agent_kwargs={"prefix": self.sys_prompt},
            # Send unparsable LLM output back to the model as an observation
            # instead of aborting the whole run.
            handle_parsing_errors=True,
        )
        print("BasicAgent initialized.")

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return its final answer text."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        time.sleep(self.request_delay_seconds)
        # `.run` is deprecated; `.invoke` returns a dict whose "output" key holds the answer.
        result = self.agent.invoke({"input": f"Answer this question: {question}"})
        fixed_answer = result["output"]
        print(f"Agent returning fixed answer: {fixed_answer}")
        return fixed_answer