Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -13,9 +13,11 @@ from langchain_community.document_loaders import ArxivLoader
|
|
| 13 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 14 |
from langchain_core.tools import tool
|
| 15 |
from langchain.tools.retriever import create_retriever_tool
|
| 16 |
-
from langchain_community.vectorstores import Chroma
|
| 17 |
-
from langchain_core.documents import Document
|
| 18 |
-
import shutil
|
|
|
|
|
|
|
| 19 |
|
| 20 |
load_dotenv()
|
| 21 |
|
|
@@ -122,13 +124,14 @@ sys_msg = SystemMessage(content=system_prompt)
|
|
| 122 |
# --- Start ChromaDB Setup ---
|
| 123 |
# Define the directory for ChromaDB persistence
|
| 124 |
CHROMA_DB_DIR = "./chroma_db"
|
|
|
|
| 125 |
|
| 126 |
# Build embeddings (this remains the same)
|
| 127 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
|
| 128 |
|
| 129 |
# Initialize ChromaDB
|
| 130 |
-
# If the directory exists, load the existing vector store.
|
| 131 |
-
# Otherwise, create a new one and add
|
| 132 |
if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
|
| 133 |
print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
|
| 134 |
vector_store = Chroma(
|
|
@@ -136,40 +139,63 @@ if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
|
|
| 136 |
embedding_function=embeddings
|
| 137 |
)
|
| 138 |
else:
|
| 139 |
-
print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and
|
| 140 |
# Ensure the directory is clean before creating new
|
| 141 |
if os.path.exists(CHROMA_DB_DIR):
|
| 142 |
shutil.rmtree(CHROMA_DB_DIR)
|
| 143 |
os.makedirs(CHROMA_DB_DIR)
|
| 144 |
|
| 145 |
-
#
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
Document(page_content="What is the capital of France?", metadata={"source": "internal", "answer": "Paris"}),
|
| 149 |
-
Document(page_content="Who wrote Hamlet?", metadata={"source": "internal", "answer": "William Shakespeare"}),
|
| 150 |
-
Document(page_content="What is the highest mountain in the world?", metadata={"source": "internal", "answer": "Mount Everest"}),
|
| 151 |
-
Document(page_content="When was the internet invented?", metadata={"source": "internal", "answer": "The internet, as we know it, evolved from ARPANET in the late 1960s and early 1970s. The TCP/IP protocol, which forms the basis of the internet, was standardized in 1978."}),
|
| 152 |
-
Document(page_content="What is the square root of 64?", metadata={"source": "internal", "answer": "8"}),
|
| 153 |
-
Document(page_content="Who is the current president of the United States?", metadata={"source": "internal", "answer": "Joe Biden"}),
|
| 154 |
-
Document(page_content="What is the chemical symbol for water?", metadata={"source": "internal", "answer": "H2O"}),
|
| 155 |
-
Document(page_content="What is the largest ocean on Earth?", metadata={"source": "internal", "answer": "Pacific Ocean"}),
|
| 156 |
-
Document(page_content="What is the speed of light?", metadata={"source": "internal", "answer": "Approximately 299,792,458 meters per second in a vacuum."}),
|
| 157 |
-
Document(page_content="What is the capital of Sweden?", metadata={"source": "internal", "answer": "Stockholm"}),
|
| 158 |
-
]
|
| 159 |
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
# Create retriever tool using the Chroma vector store
|
| 169 |
-
retriever_tool = create_retriever_tool(
|
| 170 |
retriever=vector_store.as_retriever(),
|
| 171 |
-
name="Question_Search",
|
| 172 |
-
description="A tool to retrieve similar questions from a vector store
|
| 173 |
)
|
| 174 |
|
| 175 |
# Add the new retriever tool to your list of tools
|
|
@@ -182,21 +208,17 @@ tools = [
|
|
| 182 |
wiki_search,
|
| 183 |
web_search,
|
| 184 |
arvix_search,
|
| 185 |
-
retriever_tool,
|
| 186 |
]
|
| 187 |
|
| 188 |
# Build graph function
|
| 189 |
def build_graph(provider: str = "google"):
|
| 190 |
"""Build the graph"""
|
| 191 |
-
# Load environment variables from .env file
|
| 192 |
if provider == "google":
|
| 193 |
-
# Google Gemini
|
| 194 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
|
| 195 |
elif provider == "groq":
|
| 196 |
-
|
| 197 |
-
llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
|
| 198 |
elif provider == "huggingface":
|
| 199 |
-
# TODO: Add huggingface endpoint
|
| 200 |
llm = ChatHuggingFace(
|
| 201 |
llm=HuggingFaceEndpoint(
|
| 202 |
url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
|
|
@@ -205,10 +227,9 @@ def build_graph(provider: str = "google"):
|
|
| 205 |
)
|
| 206 |
else:
|
| 207 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
| 208 |
-
|
| 209 |
llm_with_tools = llm.bind_tools(tools)
|
| 210 |
|
| 211 |
-
# Node
|
| 212 |
def assistant(state: MessagesState):
|
| 213 |
"""Assistant node"""
|
| 214 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
|
@@ -217,86 +238,29 @@ def build_graph(provider: str = "google"):
|
|
| 217 |
|
| 218 |
def retriever(state: MessagesState):
|
| 219 |
query = state["messages"][-1].content
|
| 220 |
-
# Use the
|
| 221 |
-
similar_docs =
|
| 222 |
|
| 223 |
-
# The tool returns a list of Documents, so we need to process it
|
| 224 |
-
# Assuming the tool returns a list of documents, we take the first one
|
| 225 |
if similar_docs:
|
| 226 |
-
|
| 227 |
-
#
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
# Given the original `retriever` node, it expected `similar_question[0].page_content`.
|
| 231 |
-
# If `retriever_tool.invoke(query)` returns a list of Document objects,
|
| 232 |
-
# then `similar_docs[0].page_content` is correct.
|
| 233 |
-
# If it returns a string, we need to adapt.
|
| 234 |
-
# For now, let's assume it returns a list of Documents or a string that contains the answer.
|
| 235 |
-
|
| 236 |
-
# If retriever_tool returns a string directly (as per your tool definition):
|
| 237 |
-
# content = similar_docs # This would be the string output from the tool
|
| 238 |
-
|
| 239 |
-
# If retriever_tool returns a list of Document objects from its internal retriever:
|
| 240 |
-
# Let's assume the `retriever_tool` internally uses `vector_store.as_retriever().invoke(query)`
|
| 241 |
-
# which returns a list of `Document` objects.
|
| 242 |
-
# The `create_retriever_tool` wraps this, so `retriever_tool.invoke` will return a string
|
| 243 |
-
# that is the `page_content` of the retrieved documents.
|
| 244 |
-
|
| 245 |
-
# The original `retriever` node was using `vector_store.similarity_search` directly.
|
| 246 |
-
# Now `retriever_tool` is a LangChain tool.
|
| 247 |
-
# When `retriever_tool.invoke(query)` is called, it will return the formatted string
|
| 248 |
-
# from the `create_retriever_tool` definition.
|
| 249 |
-
# So, `similar_docs` will be a string.
|
| 250 |
-
|
| 251 |
-
# We need to parse the `similar_docs` string to extract the answer.
|
| 252 |
-
# The `Question_Search` tool description is "A tool to retrieve similar questions from a vector store and their answers."
|
| 253 |
-
# The `create_retriever_tool` automatically formats the output of the retriever.
|
| 254 |
-
# Let's assume the output string from `retriever_tool.invoke(query)` will look something like:
|
| 255 |
-
# "content='What is the capital of Sweden?' metadata={'source': 'internal', 'answer': 'Stockholm'}"
|
| 256 |
-
# We need to extract the 'answer' part.
|
| 257 |
-
|
| 258 |
-
# A more robust way would be to make the retriever node *call* the tool,
|
| 259 |
-
# and then the LLM decides if it wants to use the tool.
|
| 260 |
-
# However, your current graph structure has a dedicated "retriever" node
|
| 261 |
-
# that directly fetches and returns an AIMessage.
|
| 262 |
-
|
| 263 |
-
# Let's refine the retriever node to parse the output of the tool more robustly.
|
| 264 |
-
# The `create_retriever_tool` returns a string where documents are joined.
|
| 265 |
-
# We need to extract the content that would be the "answer".
|
| 266 |
-
|
| 267 |
-
# The dummy documents have `metadata={"source": "internal", "answer": "..."}`.
|
| 268 |
-
# The `create_retriever_tool` will return `doc.page_content` by default.
|
| 269 |
-
# So, `similar_docs` will contain the question itself.
|
| 270 |
-
# We need to ensure the retriever provides the *answer* not just the question.
|
| 271 |
-
|
| 272 |
-
# Let's adjust the `retriever` node to directly access the `vector_store`
|
| 273 |
-
# for `similarity_search` and then extract the answer from metadata,
|
| 274 |
-
# similar to your original implementation. This bypasses the tool wrapper
|
| 275 |
-
# for this specific node, ensuring we get the full Document object.
|
| 276 |
-
|
| 277 |
-
similar_doc = vector_store.similarity_search(query, k=1)[0]
|
| 278 |
-
|
| 279 |
-
# Check if an 'answer' is directly available in metadata
|
| 280 |
-
if "answer" in similar_doc.metadata:
|
| 281 |
-
answer = similar_doc.metadata["answer"]
|
| 282 |
elif "Final answer :" in similar_doc.page_content:
|
| 283 |
answer = similar_doc.page_content.split("Final answer :")[-1].strip()
|
| 284 |
else:
|
| 285 |
answer = similar_doc.page_content.strip() # Fallback to page_content if no explicit answer
|
| 286 |
|
|
|
|
|
|
|
| 287 |
return {"messages": [AIMessage(content=answer)]}
|
| 288 |
else:
|
| 289 |
-
# If no similar documents found, return an empty AIMessage or a message indicating no answer
|
| 290 |
return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}
|
| 291 |
|
| 292 |
-
|
| 293 |
builder = StateGraph(MessagesState)
|
| 294 |
builder.add_node("retriever", retriever)
|
| 295 |
-
|
| 296 |
-
# Retriever ist Start und Endpunkt
|
| 297 |
builder.set_entry_point("retriever")
|
| 298 |
builder.set_finish_point("retriever")
|
| 299 |
|
| 300 |
-
# Compile graph
|
| 301 |
return builder.compile()
|
| 302 |
|
|
|
|
| 13 |
from langchain_core.messages import SystemMessage, HumanMessage
|
| 14 |
from langchain_core.tools import tool
|
| 15 |
from langchain.tools.retriever import create_retriever_tool
|
| 16 |
+
from langchain_community.vectorstores import Chroma
|
| 17 |
+
from langchain_core.documents import Document
|
| 18 |
+
import shutil
|
| 19 |
+
import pandas as pd # Ny import för pandas
|
| 20 |
+
import json # För att parsa metadata-kolumnen
|
| 21 |
|
| 22 |
load_dotenv()
|
| 23 |
|
|
|
|
| 124 |
# --- Start ChromaDB Setup ---
|
| 125 |
# Define the directory for ChromaDB persistence
|
| 126 |
CHROMA_DB_DIR = "./chroma_db"
|
| 127 |
+
CSV_FILE_PATH = "./supabase.docs.csv" # Path to your CSV file
|
| 128 |
|
| 129 |
# Build embeddings (this remains the same)
|
| 130 |
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
|
| 131 |
|
| 132 |
# Initialize ChromaDB
|
| 133 |
+
# If the directory exists and contains data, load the existing vector store.
|
| 134 |
+
# Otherwise, create a new one and add documents from the CSV file.
|
| 135 |
if os.path.exists(CHROMA_DB_DIR) and os.listdir(CHROMA_DB_DIR):
|
| 136 |
print(f"Loading existing ChromaDB from {CHROMA_DB_DIR}")
|
| 137 |
vector_store = Chroma(
|
|
|
|
| 139 |
embedding_function=embeddings
|
| 140 |
)
|
| 141 |
else:
|
| 142 |
+
print(f"Creating new ChromaDB at {CHROMA_DB_DIR} and loading documents from {CSV_FILE_PATH}.")
|
| 143 |
# Ensure the directory is clean before creating new
|
| 144 |
if os.path.exists(CHROMA_DB_DIR):
|
| 145 |
shutil.rmtree(CHROMA_DB_DIR)
|
| 146 |
os.makedirs(CHROMA_DB_DIR)
|
| 147 |
|
| 148 |
+
# Load data from the CSV file
|
| 149 |
+
if not os.path.exists(CSV_FILE_PATH):
|
| 150 |
+
raise FileNotFoundError(f"CSV file not found at {CSV_FILE_PATH}. Please ensure it's in the root directory.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
+
df = pd.read_csv(CSV_FILE_PATH)
|
| 153 |
+
documents = []
|
| 154 |
+
for index, row in df.iterrows():
|
| 155 |
+
content = row["content"]
|
| 156 |
+
|
| 157 |
+
# Extract the question part from the content
|
| 158 |
+
# Assuming the question is everything before "Final answer :"
|
| 159 |
+
question_part = content.split("Final answer :")[0].strip()
|
| 160 |
+
|
| 161 |
+
# Extract the final answer part from the content
|
| 162 |
+
final_answer_part = content.split("Final answer :")[-1].strip() if "Final answer :" in content else ""
|
| 163 |
+
|
| 164 |
+
# Parse the metadata string into a dictionary
|
| 165 |
+
# The metadata column might be stored as a string representation of a dictionary
|
| 166 |
+
try:
|
| 167 |
+
metadata = json.loads(row["metadata"].replace("'", "\"")) # Replace single quotes for valid JSON
|
| 168 |
+
except json.JSONDecodeError:
|
| 169 |
+
metadata = {} # Fallback if parsing fails
|
| 170 |
+
|
| 171 |
+
# Add the extracted final answer to the metadata for easy retrieval
|
| 172 |
+
metadata["final_answer"] = final_answer_part
|
| 173 |
+
|
| 174 |
+
# Create a Document object. The page_content should be the question for similarity search.
|
| 175 |
+
# The answer will be in metadata.
|
| 176 |
+
documents.append(Document(page_content=question_part, metadata=metadata))
|
| 177 |
+
|
| 178 |
+
if not documents:
|
| 179 |
+
print("No documents loaded from CSV. ChromaDB will be empty.")
|
| 180 |
+
# Create an empty ChromaDB if no documents are found
|
| 181 |
+
vector_store = Chroma(
|
| 182 |
+
persist_directory=CHROMA_DB_DIR,
|
| 183 |
+
embedding_function=embeddings
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
vector_store = Chroma.from_documents(
|
| 187 |
+
documents=documents,
|
| 188 |
+
embedding=embeddings,
|
| 189 |
+
persist_directory=CHROMA_DB_DIR
|
| 190 |
+
)
|
| 191 |
+
vector_store.persist() # Save the new vector store to disk
|
| 192 |
+
print(f"ChromaDB initialized and persisted with {len(documents)} documents from CSV.")
|
| 193 |
|
| 194 |
# Create retriever tool using the Chroma vector store
|
| 195 |
+
retriever_tool = create_retriever_tool(
|
| 196 |
retriever=vector_store.as_retriever(),
|
| 197 |
+
name="Question_Search",
|
| 198 |
+
description="A tool to retrieve similar questions from a vector store. The retrieved document's metadata contains the 'final_answer' to the question.",
|
| 199 |
)
|
| 200 |
|
| 201 |
# Add the new retriever tool to your list of tools
|
|
|
|
| 208 |
wiki_search,
|
| 209 |
web_search,
|
| 210 |
arvix_search,
|
| 211 |
+
retriever_tool,
|
| 212 |
]
|
| 213 |
|
| 214 |
# Build graph function
|
| 215 |
def build_graph(provider: str = "google"):
|
| 216 |
"""Build the graph"""
|
|
|
|
| 217 |
if provider == "google":
|
|
|
|
| 218 |
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
|
| 219 |
elif provider == "groq":
|
| 220 |
+
llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
|
|
|
|
| 221 |
elif provider == "huggingface":
|
|
|
|
| 222 |
llm = ChatHuggingFace(
|
| 223 |
llm=HuggingFaceEndpoint(
|
| 224 |
url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
|
|
|
|
| 227 |
)
|
| 228 |
else:
|
| 229 |
raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
|
| 230 |
+
|
| 231 |
llm_with_tools = llm.bind_tools(tools)
|
| 232 |
|
|
|
|
| 233 |
def assistant(state: MessagesState):
|
| 234 |
"""Assistant node"""
|
| 235 |
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
|
|
|
| 238 |
|
| 239 |
def retriever(state: MessagesState):
|
| 240 |
query = state["messages"][-1].content
|
| 241 |
+
# Use the vector_store directly for similarity search to get the full Document object
|
| 242 |
+
similar_docs = vector_store.similarity_search(query, k=1)
|
| 243 |
|
|
|
|
|
|
|
| 244 |
if similar_docs:
|
| 245 |
+
similar_doc = similar_docs[0]
|
| 246 |
+
# Prioritize 'final_answer' from metadata, then check page_content
|
| 247 |
+
if "final_answer" in similar_doc.metadata and similar_doc.metadata["final_answer"]:
|
| 248 |
+
answer = similar_doc.metadata["final_answer"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
elif "Final answer :" in similar_doc.page_content:
|
| 250 |
answer = similar_doc.page_content.split("Final answer :")[-1].strip()
|
| 251 |
else:
|
| 252 |
answer = similar_doc.page_content.strip() # Fallback to page_content if no explicit answer
|
| 253 |
|
| 254 |
+
# The system prompt expects "FINAL ANSWER: [ANSWER]".
|
| 255 |
+
# We should return the extracted answer directly, as the prompt handles the formatting.
|
| 256 |
return {"messages": [AIMessage(content=answer)]}
|
| 257 |
else:
|
|
|
|
| 258 |
return {"messages": [AIMessage(content="No similar questions found in the knowledge base.")]}
|
| 259 |
|
|
|
|
| 260 |
builder = StateGraph(MessagesState)
|
| 261 |
builder.add_node("retriever", retriever)
|
|
|
|
|
|
|
| 262 |
builder.set_entry_point("retriever")
|
| 263 |
builder.set_finish_point("retriever")
|
| 264 |
|
|
|
|
| 265 |
return builder.compile()
|
| 266 |
|