SerotoninRonin's picture
Add YouTube transcript extraction tool and update imports
da66358
import os
from typing import Any, List, TypedDict
from langgraph.graph import StateGraph
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.rate_limiters import InMemoryRateLimiter
from langchain_openai import ChatOpenAI
from tools import (
describe_image_tool,
parse_excel_tool,
webpage_extraction_tool,
brave_web_search,
python_code_interpreter_tool,
audio_file_transcriber,
get_youtube_transcript
)
class AgentState(TypedDict):
    """State dictionary passed between the nodes of the agent graph."""

    # Conversation so far (system / human / AI / tool message objects).
    messages: List[Any]
    # The question text the agent was asked.
    question: str
    # Path of a file associated with the question.
    # NOTE(review): not read anywhere in the visible code - confirm usage.
    file_path: str
    # Text of the final answer extracted from the last message.
    final_answer: str
# Every tool the agent may call. The same list is bound to the chat model and
# handed to the ToolNode, so the order here is the order the model sees.
tools = [
    describe_image_tool,
    parse_excel_tool,
    webpage_extraction_tool,
    brave_web_search,
    python_code_interpreter_tool,
    audio_file_transcriber,
    get_youtube_transcript,
]
# Shared limiter for all chat-model requests made by this module.
rate_limiter = InMemoryRateLimiter(
    # 0.1 requests per second, i.e. one request every 10 seconds.
    requests_per_second=0.1,
    # Upper bound on the burst size.
    max_bucket_size=10,
    # Poll every 100 ms to see whether a request is allowed yet.
    check_every_n_seconds=0.1,
)
class LangGraphAgent:
    """Question-answering agent that lets a chat model call the module's tools.

    The chat model is bound to the module-level ``tools`` list and throttled by
    the module-level ``rate_limiter``. Call :meth:`run` with a question to get
    the content of the final message back.
    """

    # Trimming is applied once the history holds more than this many messages;
    # system messages are always kept.
    MAX_HISTORY = 3  # tune as needed

    # Prompt prepended when the conversation consists of a single message.
    # NOTE(review): the fragments after the first are joined without a
    # separating space, and "wepbage" is a typo; text kept verbatim so the
    # prompt sent to the model is unchanged.
    SYSTEM_PROMPT = (
        "You are a helpful assistant that can answer questions using various tools. "
        "You must answer the given question using as few words as possible, or the given format, if any."
        "If the answer is a number, you must return the number only, do not include symbols or commas."
        "If you need to search the web for information and aren't given a URL, always use a search tool before using a wepbage extraction tool so you always have a legit website."
        "If given a Python file, execute it with the code interpreter tool (riza_exec_python)"
    )

    def __init__(self, model_name: str = "gpt-4o",):
        """Build the chat model, bind the tools and compile the graph.

        Args:
            model_name: Name of the OpenAI chat model to use.
        """
        self.llm = ChatOpenAI(model=model_name, max_tokens=2000, temperature=0, rate_limiter=rate_limiter)
        self.llm_with_tools = self.llm.bind_tools(tools)
        self.tool_node = ToolNode(tools)
        self.graph = self.create_graph().compile()

    def create_graph(self) -> StateGraph:
        """Create the (uncompiled) state graph for the agent's workflow.

        The graph starts at the ``agent`` node, routes through
        ``tools_condition`` and loops from ``tools`` back to ``agent``.
        ``agent_node`` only returns once the model's latest reply has no tool
        calls, so the ``tools`` node is presumably never reached in practice.
        """
        graph = StateGraph(AgentState)
        graph.add_node("agent", self.agent_node)
        graph.add_node("tools", self.tool_node)
        graph.set_entry_point("agent")
        graph.add_conditional_edges("agent", tools_condition)
        graph.add_edge("tools", "agent")
        return graph

    @staticmethod
    def _trim_history(messages, limit):
        """Return all system messages followed by the last ``limit`` other messages."""
        system_messages = [m for m in messages if isinstance(m, SystemMessage)]
        other_messages = [m for m in messages if not isinstance(m, SystemMessage)]
        # Taking the tail of the non-system messages only, so a system message
        # that sits inside the tail window is not included twice.
        # NOTE(review): the tail may start with a tool result whose originating
        # AI message was cut off - confirm the model API tolerates that.
        return system_messages + other_messages[-limit:]

    def agent_node(self, state: AgentState) -> dict:
        """Query the model and resolve its tool calls until it gives a plain reply.

        Args:
            state: Current graph state; only ``state["messages"]`` is read.

        Returns:
            ``{"messages": [...]}`` with the full updated history: the
            (possibly trimmed) input messages, each model reply, and the tool
            results produced in between.
        """
        # Work on a copy so the caller's list is never mutated in place.
        messages = list(state["messages"])

        if len(messages) == 1:
            messages = [SystemMessage(self.SYSTEM_PROMPT)] + messages

        # Keep system + last N messages
        if len(messages) > self.MAX_HISTORY:
            messages = self._trim_history(messages, self.MAX_HISTORY)

        # NOTE(review): nothing bounds the number of rounds; a model that keeps
        # requesting tools keeps this loop running.
        while True:
            response = self.llm_with_tools.invoke(messages)
            messages.append(response)
            if not response.tool_calls:
                break
            # ToolNode executes every tool call found on the last AI message,
            # so it is invoked once per reply. Invoking it once per individual
            # call would re-run all of them each time and append duplicates.
            tool_result = self.tool_node.invoke({"messages": [response]})
            messages.extend(tool_result["messages"])

        return {"messages": messages}

    def run(self, question: str) -> str:
        """Answer ``question`` by running the compiled graph.

        Args:
            question: The question text; sent as the only human message.

        Returns:
            The ``content`` of the last message in the resulting state, or
            ``str()`` of that message when it has no ``content`` attribute.
        """
        state = AgentState(
            messages=[HumanMessage(content=question)],
            question=question,
            final_answer=None,
        )
        result = self.graph.invoke(state)
        final_message = result["messages"][-1]
        if hasattr(final_message, "content"):
            return final_message.content
        return str(final_message)