# agent_test / agent.py
# (Hugging Face upload metadata: blazingbunny, "Upload 3 files", commit 413f406 verified)
from typing import TypedDict, Annotated, List
import operator
import os
import base64
import requests
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, END, START
from langgraph.prebuilt import ToolNode
from langchain_core.tools import tool
from langchain_community.document_loaders import YoutubeLoader, WikipediaLoader
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper
from langchain_experimental.utilities import PythonREPL
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.tools import tool
from langchain_community.tools import YouTubeSearchTool
# Playwright Imports (Optional)
try:
from langchain_community.agent_toolkits import PlaywrightBrowserToolkit
from langchain_community.tools.playwright.utils import create_sync_playwright_browser
except ImportError:
PlaywrightBrowserToolkit = None
create_sync_playwright_browser = None
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from dotenv import load_dotenv
load_dotenv()
# Configure tracing (optional): only active when both Arize credentials are
# present in the environment. Every failure mode is non-fatal so the agent
# still runs without observability.
try:
    if os.getenv("ARIZE_SPACE_ID") and os.getenv("ARIZE_API_KEY"):
        # Imported lazily so the arize/openinference packages are only
        # required when tracing is actually enabled.
        from arize.otel import register
        from openinference.instrumentation.google_genai import GoogleGenAIInstrumentor
        from openinference.instrumentation.langchain import LangChainInstrumentor
        tracer_provider = register(
            space_id=os.getenv("ARIZE_SPACE_ID"),
            api_key=os.getenv("ARIZE_API_KEY"),
            project_name=os.getenv("ARIZE_PROJECT_NAME", "langgraph-agent-test")
        )
        # Instrument both the Gemini client and LangChain so spans cover
        # model calls as well as graph/tool execution.
        GoogleGenAIInstrumentor().instrument(tracer_provider=tracer_provider)
        LangChainInstrumentor().instrument(tracer_provider=tracer_provider)
        print("Tracing configured with Arize.")
    else:
        print("Arize tracing skipped: ARIZE_SPACE_ID or ARIZE_API_KEY not set.")
except ImportError:
    print("Tracing libraries not installed. Skipping tracing.")
except Exception as e:
    print(f"Error configuring tracing: {e}")
# 1. Define the state
class AgentState(TypedDict):
    """Graph state: the running message history.

    The `operator.add` annotation tells LangGraph to concatenate each node's
    returned `messages` list onto the accumulated history instead of
    replacing it.
    """
    messages: Annotated[List[BaseMessage], operator.add]
# Helper to split and save documents to Chroma
def save_to_chroma(docs):
    """Chunk `docs` and persist them into the Chroma vector store (best effort)."""
    # Nothing to do when the store is unavailable or there are no documents.
    if not ('vector_store' in globals() and vector_store and docs):
        return
    try:
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=1000, chunk_overlap=200
        ).split_documents(docs)
        if chunks:
            vector_store.add_documents(chunks)
    except Exception as e:
        # Persistence is best effort; never break the calling tool.
        print(f"Error saving to Chroma: {e}")
# 2. Define the tools
@tool
def get_youtube_transcript(url: str) -> str:
    """Retrieves the transcript of a YouTube video given its URL."""
    try:
        documents = YoutubeLoader.from_youtube_url(url, add_video_info=True).load()
        if not documents:
            return "No transcript found. Please search Google for the video title or ID."
        # Remember the transcript for later knowledge-base lookups.
        save_to_chroma(documents)
        parts = [
            f"Metadata: {doc.metadata}\nContent: {doc.page_content}"
            for doc in documents
        ]
        return "\n\n".join(parts)
    except Exception as e:
        return f"Error getting transcript: {e}. Please try searching Google for the video URL or ID."
@tool
def calculator(expression: str) -> str:
    """Calculates a mathematical expression using Python. Example: '2 + 2', '34 * 5', 'import math; math.sqrt(2)'"""
    try:
        repl = PythonREPL()
        # PythonREPL only returns what is printed, so pure expressions must be
        # wrapped in print(...). The previous code wrapped the WHOLE input,
        # which broke multi-statement inputs such as the docstring's own
        # example 'import math; math.sqrt(2)' (print(<statements>) is a
        # SyntaxError). Wrap only the final ';'-separated segment, and only
        # when that segment actually compiles as an expression.
        if "print" not in expression:
            head, _, last = expression.rpartition(";")
            last = last.strip()
            try:
                compile(last, "<calculator>", "eval")
            except SyntaxError:
                pass  # not a printable expression; run the input unchanged
            else:
                expression = f"{head}; print({last})" if head else f"print({last})"
        return repl.run(expression)
    except Exception as e:
        return f"Error calculating: {e}"
@tool
def search_wikipedia(query: str) -> str:
    """Search Wikipedia for a query. Useful for factual lists and biographies."""
    try:
        pages = WikipediaLoader(query=query, load_max_docs=3).load()
        # Persist results so search_knowledge_base can reuse them later.
        save_to_chroma(pages)
        # Cap each article at 10k characters to keep the tool output bounded.
        return "\n\n".join(page.page_content[:10000] for page in pages)
    except Exception as e:
        return f"Error searching Wikipedia: {e}"
# ChromaDB RAG Tool
# Persistent vector store that accumulates everything the tools ingest.
# Stays None when initialization fails so tools can degrade gracefully.
vector_store = None
try:
    # Local sentence-transformer embeddings (no API key needed).
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vector_store = Chroma(
        collection_name="agent_memory",
        embedding_function=embeddings,
        persist_directory="./chroma_db"  # persisted on disk across runs
    )
except Exception as e:
    print(f"Warning: ChromaDB initialization failed. RAG features disabled. Error: {e}")
@tool
def search_knowledge_base(query: str) -> str:
    """Searches for relevant documents in the persistent knowledge base (memory of previous searches)."""
    # Explicit guard: when ChromaDB initialization failed, vector_store is
    # None and the old code surfaced a cryptic
    # "Error searching knowledge base: 'NoneType' object has no attribute..."
    if vector_store is None:
        return "Knowledge base unavailable (vector store failed to initialize)."
    try:
        docs = vector_store.as_retriever().invoke(query)
        if not docs:
            return "No relevant information found."
        return "\n".join(d.page_content for d in docs)
    except Exception as e:
        return f"Error searching knowledge base: {e}"
@tool
def browse_page(url: str) -> str:
    """Browses a web page and extracts text using Playwright. Use this to read content from specific URLs."""
    if not create_sync_playwright_browser:
        return "Browsing unavailable (Playwright not installed)."
    browser = None
    try:
        browser = create_sync_playwright_browser(headless=True)
        page = browser.new_page()
        page.goto(url)
        text = page.inner_text("body")
        # Persist the page into the shared knowledge base via the common
        # helper (the old inline splitter duplicated save_to_chroma).
        save_to_chroma([Document(page_content=text, metadata={"source": url})])
        # Truncate to keep the tool output bounded for the model.
        return text[:10000]
    except Exception as e:
        return f"Error browsing: {e}"
    finally:
        # Always release the browser process — the original leaked it
        # whenever goto()/inner_text() raised before browser.close().
        if browser is not None:
            try:
                browser.close()
            except Exception:
                pass
@tool
def search_youtube_videos(query: str) -> str:
    """Search for YouTube videos. Provide only the search keywords."""
    try:
        # Distinct local name so the imported @tool decorator is not shadowed.
        searcher = YouTubeSearchTool()
        # The ", 3" suffix asks the search tool for three results.
        return searcher.run(f"{query}, 3")
    except Exception as e:
        return f"Error searching YouTube: {e}"
# Combine Tools (Native Google Search is enabled via model param)
# Removed rag_tool/knowledge_base as it was empty -> Adding it back now
tools = [get_youtube_transcript, calculator, search_wikipedia, search_knowledge_base, search_youtube_videos, browse_page]
# Single graph node that executes whichever tool calls the model emits.
tool_node = ToolNode(tools)
# 3. Define the model
LLM = "gemini-2.0-flash"
model = ChatGoogleGenerativeAI(
    model=LLM,
    temperature=0,  # deterministic output for reproducible answers
    max_retries=5,  # tolerate transient API errors / rate limits
    google_search_retrieval=True  # enable Gemini's native Google Search grounding
)
# Expose the custom tools so the model can emit tool calls for them.
model = model.bind_tools(tools)
# 4. Define the agent node
def should_continue(state):
    """Route after the agent turn: "continue" to run requested tools, "end" to stop."""
    last = state['messages'][-1]
    return "continue" if last.tool_calls else "end"
def call_model(state):
    """Agent node: send the accumulated history to the model and append its reply."""
    reply = model.invoke(state['messages'])
    # Returned list is concatenated onto state via the operator.add reducer.
    return {"messages": [reply]}
# 5. Create the graph
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)  # LLM reasoning step
workflow.add_node("action", tool_node)  # tool execution step
workflow.add_edge(START, "agent")
# After each agent turn, either execute requested tools or finish.
workflow.add_conditional_edges("agent", should_continue, {"continue": "action", "end": END})
# Tool results feed back into the agent for another reasoning turn.
workflow.add_edge("action", "agent")
app = workflow.compile()
class LangGraphAgent:
    """Callable wrapper around the compiled LangGraph `app`.

    `agent(question, task_id)` runs the graph and returns the text after
    "FINAL ANSWER:" when the model followed the output contract, otherwise
    the model's full final message.
    """

    # System prompt: forces a search-first workflow and the
    # "FINAL ANSWER: ..." output contract that __call__ parses.
    _SYSTEM_PROMPT = """You are a helpful assistant with multimodal capabilities (Vision, Audio, PDF analysis).
Step 1: ALWAYS START by performing a Google Search (or using Wikipedia/YouTube) to gather up-to-date information. Do not answer from memory.
Step 2: If a URL is provided, search for the **EXACT URL** string on Google first to identify the video/page title. Do not add keywords yet. **DO NOT use the 'youtube_search' tool for this step; use Google Search.**
Step 3: Once you have the title, search for that title to find descriptions or summaries.
Step 4: Analyze the information found. If you cannot access a specific page or video directly (e.g. empty transcript), DO NOT GIVE UP. Use Google Search to find descriptions, summaries, or discussions from reliable sources.
Step 5: If you identify relevant Wikipedia pages or YouTube videos, use the specific tools ('search_wikipedia', 'get_youtube_transcript') to ingest them into your Knowledge Base.
Step 6: Reason to find the exact answer. Verify your findings by cross-referencing multiple sources if possible. You can use 'search_knowledge_base' to connect facts you have saved.
Step 7: Output the final answer strictly in this format:
FINAL ANSWER: [ANSWER]
Do not include "FINAL ANSWER:" in the [ANSWER] part itself.
Example:
Thinking: ...
FINAL ANSWER: 3
If the question involves an image, video, or audio file provided in the context, analyze it to answer.
"""

    def __init__(self):
        self.app = app

    @staticmethod
    def _fetch_task_file(task_id):
        """Return a multimodal content part for the task's attachment, or None.

        Probes the scoring endpoint with a HEAD request first and only
        downloads files whose Content-Type is image/audio/video/PDF; the
        payload is inlined as a base64 data URI. Any network failure is
        logged and treated as "no attachment".
        """
        file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
        try:
            # Check headers first so large unsupported files are never downloaded.
            head = requests.head(file_url, timeout=5)
            mime_type = head.headers.get("Content-Type", "")
            # Allow images, audio, video, pdf
            supported = ["image/", "audio/", "video/", "application/pdf"]
            if head.status_code != 200 or not any(t in mime_type for t in supported):
                return None
            body = requests.get(file_url, timeout=10)
            if body.status_code != 200:
                return None
            encoded = base64.b64encode(body.content).decode("utf-8")
            return {
                "type": "image_url",  # LangChain uses this key for multimodal data URIs
                "image_url": {"url": f"data:{mime_type};base64,{encoded}"},
            }
        except Exception as e:
            print(f"Error checking/fetching file: {e}")
            return None

    @staticmethod
    def _extract_text(content):
        """Flatten a LangChain message `content` (str, list, or dict) into plain text."""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            return " ".join(LangGraphAgent._extract_text(c) for c in content)
        if isinstance(content, dict):
            return content.get('text', str(content))
        return str(content)

    def __call__(self, question: str, task_id: str = None) -> str:
        """Run the agent on `question`, optionally attaching the task's file.

        Args:
            question: The user question to answer.
            task_id: Optional scoring-service task id; when set, the task's
                file (image/audio/video/PDF) is fetched and attached.

        Returns:
            The answer text (portion after "FINAL ANSWER:" when present).
        """
        content = [{"type": "text", "text": question}]
        if task_id:
            file_part = self._fetch_task_file(task_id)
            if file_part:
                content.append(file_part)
        messages = [
            SystemMessage(content=self._SYSTEM_PROMPT),
            HumanMessage(content=content),
        ]
        final_state = self.app.invoke({"messages": messages})
        text_result = self._extract_text(final_state['messages'][-1].content)
        # Honor the output contract; fall back to the raw text otherwise.
        if "FINAL ANSWER:" in text_result:
            return text_result.split("FINAL ANSWER:")[-1].strip()
        return text_result