Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
|
@@ -1,17 +1,34 @@
|
|
| 1 |
# agent.py
|
| 2 |
import os
|
| 3 |
-
from supabase import create_client
|
| 4 |
from sentence_transformers import SentenceTransformers
|
| 5 |
from serpapi import GoogleSearch
|
|
|
|
|
|
|
| 6 |
from langgraph import Graph, LLM, tool #or other graph library
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
# ----Supabase setup----
|
| 9 |
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
| 10 |
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
|
| 11 |
EMBED_MODEL_ID = os.getenv("HF_EMBEDDING_MODEL")
|
| 12 |
|
| 13 |
|
| 14 |
-
sb_client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
| 15 |
embedder = SentenceTransformers(EMBED_MODEL_ID)
|
| 16 |
|
| 17 |
# 1) Define tools
|
|
@@ -27,26 +44,41 @@ def calculator(expr: str) -> str:
|
|
| 27 |
# @tool
|
| 28 |
# def web_search(query:str) -> str:
|
| 29 |
# ...
|
| 30 |
-
@tool
|
| 31 |
-
def retrieve_docs(query: str, k: int = 3) -> str:
|
| 32 |
-
"""
|
| 33 |
-
Fetch tpo-k docs from Supabase vector store.
|
| 34 |
-
Returns the concatenated text.
|
| 35 |
-
"""
|
| 36 |
# --- embed the query
|
| 37 |
-
q_emb = embedder.encode(query).tolist()
|
| 38 |
|
| 39 |
# --- query the embedding table
|
| 40 |
-
response = (
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
# ---- concatenate the content field
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
SERPAPI_KEY = os.getenv("SERPAPY_KEY")
|
| 52 |
# ---- web_search tool
|
|
|
|
| 1 |
# agent.py
|
| 2 |
import os
|
| 3 |
+
#from supabase import create_client
|
| 4 |
from sentence_transformers import SentenceTransformers
|
| 5 |
from serpapi import GoogleSearch
|
| 6 |
+
import pandas as pd
|
| 7 |
+
import faiss
|
| 8 |
from langgraph import Graph, LLM, tool #or other graph library
|
| 9 |
|
| 10 |
+
|
| 11 |
+
# ─── 1) Load & embed all documents at startup ───
|
| 12 |
+
# 1a) Read CSV of docs
|
| 13 |
+
df = pd.read_csv("documents.csv")
|
| 14 |
+
DOCS = df["content"].tolist()
|
| 15 |
+
|
| 16 |
+
# 1b) Create an embedding model
|
| 17 |
+
EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
|
| 18 |
+
|
| 19 |
+
# 1c) Compute embeddings (float32) and build FAISS index
|
| 20 |
+
EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
|
| 21 |
+
INDEX = faiss.IndexFlatL2(EMBS.shape[1])
|
| 22 |
+
INDEX.add(EMBS)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
# ----Supabase setup----
|
| 26 |
SUPABASE_URL = os.getenv("SUPABASE_URL")
|
| 27 |
SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
|
| 28 |
EMBED_MODEL_ID = os.getenv("HF_EMBEDDING_MODEL")
|
| 29 |
|
| 30 |
|
| 31 |
+
#sb_client = create_client(SUPABASE_URL, SUPABASE_KEY)
|
| 32 |
embedder = SentenceTransformers(EMBED_MODEL_ID)
|
| 33 |
|
| 34 |
# 1) Define tools
|
|
|
|
| 44 |
# @tool
|
| 45 |
# def web_search(query:str) -> str:
|
| 46 |
# ...
|
| 47 |
+
#@tool
|
| 48 |
+
#def retrieve_docs(query: str, k: int = 3) -> str:
|
| 49 |
+
#"""
|
| 50 |
+
#Fetch tpo-k docs from Supabase vector store.
|
| 51 |
+
#Returns the concatenated text.
|
| 52 |
+
#"""
|
| 53 |
# --- embed the query
|
| 54 |
+
#q_emb = embedder.encode(query).tolist()
|
| 55 |
|
| 56 |
# --- query the embedding table
|
| 57 |
+
#response = (
|
| 58 |
+
# sb_client
|
| 59 |
+
# .rpc("match_documents", {"query_embedding": q_emb, "match_count": k})
|
| 60 |
+
# .execute()
|
| 61 |
+
# )
|
| 62 |
+
# rows = response.data
|
| 63 |
|
| 64 |
# ---- concatenate the content field
|
| 65 |
+
# docs = [row["content"] for row in rows]
|
| 66 |
+
# return "\n\n---\n\n".join(docs)
|
| 67 |
+
|
| 68 |
+
@tool
|
| 69 |
+
def retrieve_docs(query: str, k: int = 3) -> str:
|
| 70 |
+
"""
|
| 71 |
+
k-NN search over our in-memory FAISS index.
|
| 72 |
+
Returns the top-k documents concatenated.
|
| 73 |
+
"""
|
| 74 |
+
# 1) Embed the query
|
| 75 |
+
q_emb = EMBEDDER.encode([query]).astype("float32")
|
| 76 |
+
# 2) Search FAISS
|
| 77 |
+
D, I = INDEX.search(q_emb, k)
|
| 78 |
+
# 3) Gather and return the texts
|
| 79 |
+
hits = [DOCS[i] for i in I[0]]
|
| 80 |
+
return "\n\n---\n\n".join(hits)
|
| 81 |
+
|
| 82 |
|
| 83 |
SERPAPI_KEY = os.getenv("SERPAPY_KEY")
|
| 84 |
# ---- web_search tool
|