# mutyamjai's picture
# Update agent.py
# be0bcb0 verified
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys
from langchain_core.tools import tool
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
import time
from youtube_transcript_api import YouTubeTranscriptApi
import re
import yaml
from langchain_core.messages import SystemMessage
from langgraph.graph import StateGraph, START
from langgraph.prebuilt import ToolNode, tools_condition
from typing import TypedDict, Annotated
from langgraph.graph.message import add_messages
from langchain_openai import ChatOpenAI
from langchain_ollama import ChatOllama
from langchain_groq import ChatGroq
from dotenv import load_dotenv
import os
from tools import tools
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.tools import tool
from duckduckgo_search import DDGS
from selenium.webdriver.chrome.service import Service
from webdriver_manager.chrome import ChromeDriverManager
driver = None


def get_driver():
    """Return the shared Chrome WebDriver, creating it on first call.

    The instance is cached in the module-level ``driver`` global, so every
    tool in this module drives the same browser session.
    """
    global driver
    if driver is not None:
        return driver

    chrome_opts = webdriver.ChromeOptions()
    # Headless mode is flagged as important by the original author; presumably
    # the agent runs where no display is available — confirm for your host.
    for flag in ("--no-sandbox", "--disable-dev-shm-usage", "--headless"):
        chrome_opts.add_argument(flag)

    driver = webdriver.Chrome(
        service=Service(ChromeDriverManager().install()),
        options=chrome_opts,
    )
    return driver
@tool
def search_in_page(keyword: str, context_chars: int = 400, max_matches: int = 10) -> str:
    """
    Search for a keyword (case-insensitive) in the current page and return up to
    max_matches occurrences (default 10) with surrounding context.
    Use this to find specific information.
    """
    # An empty keyword matches at every position without advancing, which used
    # to yield the same leading text repeated until the cap was hit.
    if not keyword:
        return "Please provide a non-empty keyword to search for."
    context_chars = max(0, context_chars)
    max_matches = max(1, max_matches)

    d = get_driver()
    # Raw string so the JavaScript regex escapes below reach the browser intact.
    matches = d.execute_script(r"""
const keyword = arguments[0];
const ctx = arguments[1];
const maxMatches = arguments[2];
const body = document.body ? document.body.innerText : "";
// Escape regex metacharacters so the keyword is matched literally.
const escaped = keyword.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
// Case-insensitive regex keeps match indices in the ORIGINAL text; lowercasing a
// copy can change its length for some characters and misalign the context.
const re = new RegExp(escaped, "gi");
const out = [];
let m;
while (out.length < maxMatches && (m = re.exec(body)) !== null) {
    const start = m.index;
    const end = start + m[0].length;
    out.push({
        context: body.substring(Math.max(0, start - ctx), Math.min(body.length, end + ctx))
    });
}
return out;
""", keyword, context_chars, max_matches)

    if not matches:
        return f"Keyword '{keyword}' not found on page."

    blocks = [
        f"[Match {i}]\n{m['context']}\n{'─' * 40}"
        for i, m in enumerate(matches, 1)
    ]
    header = f"Found {len(matches)} match(es) for '{keyword}'"
    # Tell the agent when the cap was hit so it doesn't assume it saw everything.
    if len(matches) >= max_matches:
        header += f" (showing the first {max_matches}; there may be more)"
    return header + ":\n\n" + "\n".join(blocks)
@tool
def go_back() -> str:
    """Goes back to previous page."""
    # Step one entry back in the shared browser session's history.
    get_driver().back()
    return "Went back to previous page"
@tool
def close_popups() -> str:
    """
    Closes any visible modal or pop-up on the page. Use this to dismiss pop-up windows!
    This does not work on cookie consent banners.
    """
    # Implementation: a single ESC key press sent through an action chain; the
    # page decides whether that dismisses anything, hence "Attempted".
    chain = webdriver.ActionChains(get_driver())
    chain.send_keys(Keys.ESCAPE)
    chain.perform()
    return "Attempted to close popups"
@tool
def go_to(url: str) -> str:
    """Open a webpage using URL."""
    browser = get_driver()
    browser.get(url)

    def _page_loaded(drv):
        # Ready once the document reports it has finished loading.
        return drv.execute_script("return document.readyState") == "complete"

    # Wait at most 10 seconds; a timeout propagates as a Selenium exception.
    WebDriverWait(browser, 10).until(_page_loaded)
    return f"Opened {url}"
@tool
def click_element(selector: str, by: str = "css") -> str:
    """
    Click an element. by = 'css', 'xpath', 'text', or 'id'.
    Example: click_element("a[href*='trending']")
    """
    d = get_driver()
    by_map = {
        "css": By.CSS_SELECTOR,
        "xpath": By.XPATH,
        "id": By.ID,
        "text": By.LINK_TEXT,
    }
    # Normalise the strategy name: previously 'XPath', 'ID' or ' css ' missed the
    # map and were silently treated as CSS selectors. Unknown names (and None)
    # still fall back to CSS as before.
    strategy = (by or "css").strip().lower()
    locator = by_map.get(strategy, By.CSS_SELECTOR)

    el = WebDriverWait(d, 8).until(EC.element_to_be_clickable((locator, selector)))
    el.click()
    # Fixed pause, presumably so navigation/JS triggered by the click can start
    # before current_url is read.
    time.sleep(1)
    return f"Clicked. Now at: {d.current_url}"
@tool
def get_page_text(max_chars: int = 8000) -> str:
    """Get the full visible text of the current page (truncated)."""
    page_text = get_driver().execute_script("return document.body.innerText;")
    # Empty text and a missing result share the same message.
    if not page_text:
        return "No text found."
    return page_text[:max_chars]
@tool
def get_elements_text(css_selector: str, limit: int = 10) -> str:
    """
    Extract text from multiple matching elements.
    Great for lists, tables, repo cards, profile stats, etc.
    Example: get_elements_text("article.Box-row", limit=5)
    """
    d = get_driver()
    elements = d.find_elements(By.CSS_SELECTOR, css_selector)[:limit]
    if not elements:
        return f"No elements found for selector: {css_selector}"

    results = []
    # Numbering follows element position, so blank elements leave gaps.
    for i, el in enumerate(elements, 1):
        # .text is a WebDriver round trip; read it once per element.
        text = el.text.strip()
        if text:
            results.append(f"[{i}] {text}")

    # Previously this case returned "", leaving the agent with no explanation.
    if not results:
        return (
            f"Found {len(elements)} element(s) for selector: {css_selector}, "
            "but none has visible text."
        )
    return "\n\n".join(results)
@tool
def get_youtube_transcript(video_url: str = None, max_chars: int = 5000) -> str:
    """Get transcript of current or given YouTube video."""
    # Only consult the browser when no URL is supplied: get_driver() launches
    # Chrome on first use, which is wasted work when the URL is already known.
    url = video_url or get_driver().current_url

    # YouTube video IDs are 11 characters; accept watch?v=, youtu.be/, embed/,
    # shorts/ and live/ URL forms.
    match = re.search(r"(?:v=|youtu\.be/|embed/|shorts/|live/)([a-zA-Z0-9_-]{11})", url)
    if not match:
        return "FAILED: Could not extract video ID from URL"
    video_id = match.group(1)

    try:
        fetched = YouTubeTranscriptApi().fetch(video_id)
        full_text = " ".join(snippet.text for snippet in fetched)
    except Exception as e:
        # Tool boundary: report the failure to the agent instead of raising.
        return f"FAILED: {e}"
    # Truncate so a long transcript doesn't flood the model context.
    return full_text[:max_chars]
@tool
def web_search(query: str) -> str:
    """
    Search the web for real-time information.
    Use this when:
    - You need up-to-date information
    - The answer is not on the current webpage
    - You need external knowledge
    Input: a clear search query
    Output: top search results with title, snippet, and URL
    """
    # Any failure (network, rate limit, unexpected result shape) is reported to
    # the agent as text rather than raised.
    try:
        with DDGS() as ddgs:
            formatted = [
                f"Title: {hit['title']}\n"
                f"Snippet: {hit['body']}\n"
                f"URL: {hit['href']}\n"
                for hit in ddgs.text(query, max_results=5)
            ]
    except Exception as e:
        return f"Search error: {str(e)}"

    if not formatted:
        return "No results found."
    return "\n\n".join(formatted)
# Tools exposed to the agent, in the order they are presented. This list is
# both bound to the LLM and handed to the graph's ToolNode.
toolsused = [
    # navigation and interaction
    go_to, click_element,
    # reading the current page
    search_in_page, get_elements_text, get_page_text,
    # browser housekeeping
    go_back, close_popups,
    # external sources
    get_youtube_transcript, web_search,
]
load_dotenv()

# Load the system prompt from the "prompt" key of config.yaml. The path is
# relative to the current working directory.
with open("config.yaml", encoding="utf-8") as f:
    SYSTEM_PROMPT = yaml.safe_load(f)["prompt"]

# LLM: hosted Groq model (not local). Credentials presumably come from the
# environment / .env loaded above — confirm the expected variable name.
# Alternative Hugging Face router configuration, kept for reference:
# llm = ChatOpenAI(
#     base_url="https://router.huggingface.co/v1",
#     api_key=os.getenv("HF_TOKEN"),
#     model="openai/gpt-oss-120b:groq",
# )
llm = ChatGroq(model="qwen/qwen3-32b", temperature=0)
# The leftover debug print of type(tools), run on every import, was removed.
llm_with_tools = llm.bind_tools(toolsused)
# State
# Graph state shared by every node. `add_messages` is the reducer LangGraph
# applies to "messages", so nodes return only their new messages and those are
# merged into the running history instead of replacing it.
AgentState = TypedDict("AgentState", {"messages": Annotated[list, add_messages]})
# Assistant node
def assistant(state: AgentState):
    """Graph node: run the tool-bound LLM over the system prompt plus history.

    The system prompt is prepended on every call rather than stored in state.
    Returns a partial state update holding only the model's reply, which the
    ``messages`` reducer merges into the conversation.
    """
    prompt = [SystemMessage(content=SYSTEM_PROMPT), *state["messages"]]
    reply = llm_with_tools.invoke(prompt)
    return {"messages": [reply]}
# Graph
# Wiring: START -> "assistant". After each assistant turn, tools_condition picks
# the next hop: the node named "tools" when the reply requests tool calls,
# otherwise the end of the run. "tools" always hands control back to "assistant".
builder = StateGraph(AgentState)
builder.add_node("assistant", assistant)
# Executes requested tool calls using the same list bound to the LLM.
builder.add_node("tools", ToolNode(toolsused))
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)
builder.add_edge("tools", "assistant")
# Compiled, runnable graph invoked by run_agent().
agent = builder.compile()
class ToolLogger(BaseCallbackHandler):
    """Print each tool call the agent makes: name, input, and output or error."""

    # Tool outputs can be whole pages of text; only echo a short prefix.
    _PREVIEW_CHARS = 200

    def on_tool_start(self, serialized, input_str, **kwargs):
        # NOTE(review): `serialized` may be None or lack "name" depending on how
        # the tool was dispatched. Fall back to the `name` kwarg instead of
        # raising KeyError inside the callback.
        name = (serialized or {}).get("name") or kwargs.get("name") or "<unknown tool>"
        print(f"\n🔧 TOOL CALLED: {name}")
        print(f" INPUT: {input_str}")

    def on_tool_end(self, output, **kwargs):
        print(f" OUTPUT: {str(output)[:self._PREVIEW_CHARS]}")

    def on_tool_error(self, error, **kwargs):
        # Without this, a failing tool logged its start but never an outcome.
        print(f" ERROR: {str(error)[:self._PREVIEW_CHARS]}")
def run_agent(question: str) -> str:
    """Run the compiled agent graph on a single user question.

    Tool calls are printed by a fresh ToolLogger. The recursion limit caps
    graph steps at 75. The final message's content is printed and returned.
    """
    initial_state = {"messages": [{"role": "user", "content": question}]}
    run_config = {"recursion_limit": 75, "callbacks": [ToolLogger()]}
    final_state = agent.invoke(initial_state, config=run_config)

    answer = final_state["messages"][-1].content
    print(answer)
    return answer