Spaces:
Running
Running
# --- Standard library ---
import sys
import os
import uuid
import json

# --- Third-party ---
from dotenv import load_dotenv
from typing import Annotated, List, Tuple
from typing_extensions import TypedDict
from langchain.tools import tool, BaseTool
from langchain.schema import Document
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langgraph.checkpoint.memory import MemorySaver
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, AIMessagePromptTemplate, HumanMessagePromptTemplate
# from langchain.schema import SystemMessage, HumanMessage, AIMessage, ToolMessage
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage
from langchain.retrievers.multi_query import MultiQueryRetriever

# --- Project-local (parent dir added to path so src/ and prompts/ resolve) ---
sys.path.append(os.path.abspath('..'))
import src.utils.qdrant_manager as qm
import prompts.system_prompts as sp

# Load environment variables (API keys, etc.). The .env location can be
# overridden via SRF_CHATBOT_ENV so the code is not tied to one developer's
# machine; the original hard-coded path remains the default for compatibility.
load_dotenv(os.getenv('SRF_CHATBOT_ENV', '/Users/nadaa/Documents/code/py_innovations/srf_chatbot_v2/.env'))
class ToolManager:
    """Builds and holds the retrieval tools backed by a Qdrant vector store."""

    def __init__(self, collection_name: str = "openai_large_chunks_1000char"):
        # collection_name: the Qdrant collection holding the embedded passages.
        self.tools: list = []
        self.qdrant = qm.QdrantManager(collection_name=collection_name)
        self.vectorstore = self.qdrant.get_vectorstore()
        self.add_tools()

    def get_tools(self) -> list:
        """Return the registered tool list (BaseTool instances)."""
        return self.tools

    def add_tools(self) -> None:
        """Define and register the vector-search tools.

        BUG FIX: both inner functions are now wrapped with @tool. As plain
        functions they lacked the .name / .invoke attributes that
        BasicToolNode uses for dispatch, so tool execution would raise
        AttributeError at runtime.
        """

        @tool
        def vector_search(query: str, k: int = 15) -> list[Document]:
            """Useful for simple queries. This tool will search a vector database for passages from the teachings of Paramhansa Yogananda and other publications from the Self Realization Fellowship (SRF).
            The user has the option to specify the number of passages they want the search to return, otherwise the number of passages will be set to the default value."""
            retriever = self.vectorstore.as_retriever(search_kwargs={"k": k})
            documents = retriever.invoke(query)
            return documents

        @tool
        def multiple_query_vector_search(query: str, k: int = 15) -> list[Document]:
            """Useful when the user's query is vague, complex, or involves multiple concepts.
            This tool will write multiple versions of the user's query and search the vector database for relevant passages.
            Use this tool when the user asks for an in depth answer to their question."""
            llm = ChatOpenAI(model="gpt-4o-mini", temperature=0.5)
            # BUG FIX: propagate k to the underlying retriever; previously the
            # parameter was accepted but silently ignored.
            retriever_from_llm = MultiQueryRetriever.from_llm(
                retriever=self.vectorstore.as_retriever(search_kwargs={"k": k}),
                llm=llm,
            )
            documents = retriever_from_llm.invoke(query)
            return documents

        self.tools.append(vector_search)
        self.tools.append(multiple_query_vector_search)
class BasicToolNode:
    """A node that runs the tools requested in the last AIMessage."""

    def __init__(self, tools: list) -> None:
        # Index tools by name for O(1) dispatch per tool call.
        self.tools_by_name = {t.name: t for t in tools}

    def __call__(self, inputs: dict):
        """Execute every tool call found on the most recent message.

        Returns a dict with the resulting ToolMessages under "messages" and
        the flattened retrieval results under "documents".
        Raises ValueError when the input carries no messages.
        """
        messages = inputs.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        last_message = messages[-1]

        tool_messages = []
        retrieved_docs = []
        for call in last_message.tool_calls:
            result = self.tools_by_name[call["name"]].invoke(call["args"])
            tool_messages.append(
                ToolMessage(
                    content=str(result),
                    name=call["name"],
                    tool_call_id=call["id"],
                )
            )
            # Each tool returns a list of documents; accumulate them flat.
            retrieved_docs.extend(result)
        return {"messages": tool_messages, "documents": retrieved_docs}
class AgentState(TypedDict):
    """State carried through the LangGraph conversation graph."""

    # Conversation history; add_messages merges updates by appending rather
    # than overwriting the list.
    messages: Annotated[list, add_messages]
    # Documents returned by the vector-search tools for the latest turn.
    documents: list[Document]
    # Active system prompt(s) for the conversation.
    system_message: list[SystemMessage]
    # Presumably the UI dropdown selection naming the chosen system prompt —
    # TODO confirm against the caller that populates this field.
    system_message_dropdown: list[str]
class GenericChatbot:
    """A LangGraph chatbot that can call the SRF vector-search tools.

    Builds a two-node graph (chatbot <-> tools) checkpointed in memory, and
    truncates overly long conversations to stay within a rough token budget.
    """

    # Rough context budget; tokens are estimated at ~4 characters per token.
    TOKEN_BUDGET = 100_000

    def __init__(
        self,
        model: str = 'gpt-4o-mini',
        temperature: float = 0,
        max_messages: int = 10,
    ):
        self.llm = ChatOpenAI(model=model, temperature=temperature)
        self.tools = ToolManager().get_tools()
        self.llm_with_tools = self.llm.bind_tools(self.tools)
        # NOTE(review): max_messages is stored but not read anywhere in this
        # class as written — confirm whether it is used by callers.
        self.max_messages = max_messages

        # Build the graph
        self.graph = self.build_graph()

        # Get the configurable
        self.config = self.get_configurable()

    def get_configurable(self):
        """Create a per-instance LangGraph config with a fresh thread id."""
        # This thread id is used to keep track of the chatbot's conversation
        self.thread_id = str(uuid.uuid4())
        return {"configurable": {"thread_id": self.thread_id}}

    @staticmethod
    def _estimate_tokens(message) -> int:
        """Rough token estimate for one message: ~4 characters per token."""
        return len(str(message.content)) // 4

    def _truncate_messages(self, messages: list) -> list:
        """Drop the OLDEST non-system messages until the history fits the budget.

        The system message (when present at index 0) is always kept, and a
        notice about the truncation is appended for the user.

        BUG FIX: the previous implementation iterated newest-first and removed
        messages while still over budget, which discarded the most *recent*
        turns and kept the oldest — the opposite of the stated intent. It also
        inserted kept messages at index 1 even when there was no system
        message, which reversed their chronological order.
        """
        head: list = []
        body = messages
        # Always keep system message if present
        if messages and isinstance(messages[0], SystemMessage):
            head = [messages[0]]
            body = messages[1:]

        budget = self.TOKEN_BUDGET - sum(self._estimate_tokens(m) for m in head)
        kept: list = []
        used = 0
        # Walk newest-first, keeping messages while they fit under the budget.
        for message in reversed(body):
            message_tokens = self._estimate_tokens(message)
            if used + message_tokens > budget:
                break
            kept.append(message)
            used += message_tokens
        kept.reverse()  # restore chronological order

        truncated = head + kept
        # Inform user about truncation
        truncated.append(
            AIMessage(content="I notice our conversation has gotten quite long. I've kept the most recent and relevant parts to ensure we can continue effectively.")
        )
        return truncated

    # Add the system message onto the llm
    ## THIS SHOULD BE REFACTORED SO THAT THE STATE ALWAYS HAS THE DEFINITIVE SYSTEM MESSAGE THAT SHOULD BE IN USE
    def chatbot(self, state: AgentState):
        """Graph node: truncate the history if needed and invoke the LLM."""
        messages = state["messages"]
        # Estimate total tokens in the conversation so far.
        total_tokens = sum(self._estimate_tokens(m) for m in messages)
        # If over the budget, keep the system message plus the most recent
        # messages that fit.
        if total_tokens > self.TOKEN_BUDGET:
            messages = self._truncate_messages(messages)
        return {"messages": [self.llm_with_tools.invoke(messages)]}

    def build_graph(self):
        """Assemble and compile the chatbot <-> tools graph."""
        # Add chatbot state
        graph_builder = StateGraph(AgentState)

        # Add nodes
        tool_node = BasicToolNode(tools=self.tools)
        graph_builder.add_node("tools", tool_node)
        graph_builder.add_node("chatbot", self.chatbot)

        # Add a conditional edge wherein the chatbot can decide whether or not to go to the tools
        graph_builder.add_conditional_edges(
            "chatbot",
            tools_condition,
        )

        # Add fixed edges
        graph_builder.add_edge(START, "chatbot")
        graph_builder.add_edge("tools", "chatbot")

        # Instantiate the memory saver
        memory = MemorySaver()

        # Compile the graph
        return graph_builder.compile(checkpointer=memory)