AC-Angelo93 committed on
Commit
8d47cbc
·
verified ·
1 Parent(s): e4b06ec

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +49 -17
agent.py CHANGED
@@ -1,17 +1,34 @@
1
  # agent.py
2
  import os
3
- from supabase import create_client
4
  from sentence_transformers import SentenceTransformers
5
  from serpapi import GoogleSearch
 
 
6
  from langgraph import Graph, LLM, tool #or other graph library
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # ----Supabase setup----
9
  SUPABASE_URL = os.getenv("SUPABASE_URL")
10
  SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
11
  EMBED_MODEL_ID = os.getenv("HF_EMBEDDING_MODEL")
12
 
13
 
14
- sb_client = create_client(SUPABASE_URL, SUPABASE_KEY)
15
  embedder = SentenceTransformers(EMBED_MODEL_ID)
16
 
17
  # 1) Define tools
@@ -27,26 +44,41 @@ def calculator(expr: str) -> str:
27
  # @tool
28
  # def web_search(query:str) -> str:
29
  # ...
30
- @tool
31
- def retrieve_docs(query: str, k: int = 3) -> str:
32
- """
33
- Fetch top-k docs from Supabase vector store.
34
- Returns the concatenated text.
35
- """
36
  # --- embed the query
37
- q_emb = embedder.encode(query).tolist()
38
 
39
  # --- query the embedding table
40
- response = (
41
- sb_client
42
- .rpc("match_documents", {"query_embedding": q_emb, "match_count": k})
43
- .execute()
44
- )
45
- rows = response.data
46
 
47
  # ---- concatenate the content field
48
- docs = [row["content"] for row in rows]
49
- return "\n\n---\n\n".join(docs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  SERPAPI_KEY = os.getenv("SERPAPY_KEY")
52
  # ---- web_search tool
 
1
  # agent.py
2
  import os
3
+ #from supabase import create_client
4
  from sentence_transformers import SentenceTransformers
5
  from serpapi import GoogleSearch
6
+ import pandas as pd
7
+ import faiss
8
  from langgraph import Graph, LLM, tool #or other graph library
9
 
10
+
11
+ # ─── 1) Load & embed all documents at startup ───
12
+ # 1a) Read CSV of docs
13
+ df = pd.read_csv("documents.csv")
14
+ DOCS = df["content"].tolist()
15
+
16
+ # 1b) Create an embedding model
17
+ EMBEDDER = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
18
+
19
+ # 1c) Compute embeddings (float32) and build FAISS index
20
+ EMBS = EMBEDDER.encode(DOCS, show_progress_bar=True).astype("float32")
21
+ INDEX = faiss.IndexFlatL2(EMBS.shape[1])
22
+ INDEX.add(EMBS)
23
+
24
+
25
  # ----Supabase setup----
26
  SUPABASE_URL = os.getenv("SUPABASE_URL")
27
  SUPABASE_KEY = os.getenv("SUPABASE_SERVICE_KEY")
28
  EMBED_MODEL_ID = os.getenv("HF_EMBEDDING_MODEL")
29
 
30
 
31
+ #sb_client = create_client(SUPABASE_URL, SUPABASE_KEY)
32
  embedder = SentenceTransformers(EMBED_MODEL_ID)
33
 
34
  # 1) Define tools
 
44
  # @tool
45
  # def web_search(query:str) -> str:
46
  # ...
47
+ #@tool
48
+ #def retrieve_docs(query: str, k: int = 3) -> str:
49
+ #"""
50
+ #Fetch top-k docs from Supabase vector store.
51
+ #Returns the concatenated text.
52
+ #"""
53
  # --- embed the query
54
+ #q_emb = embedder.encode(query).tolist()
55
 
56
  # --- query the embedding table
57
+ #response = (
58
+ # sb_client
59
+ # .rpc("match_documents", {"query_embedding": q_emb, "match_count": k})
60
+ # .execute()
61
+ # )
62
+ # rows = response.data
63
 
64
  # ---- concatenate the content field
65
+ # docs = [row["content"] for row in rows]
66
+ # return "\n\n---\n\n".join(docs)
67
+
68
@tool
def retrieve_docs(query: str, k: int = 3) -> str:
    """
    k-NN search over the in-memory FAISS index built at startup.

    Args:
        query: Natural-language search query.
        k: Number of documents to retrieve (default 3).

    Returns:
        The top-k matching document texts concatenated, separated by a
        "---" delimiter line. Empty string if nothing is retrieved.
    """
    # 1) Embed the query — FAISS expects a float32 2-D array (batch of 1).
    q_emb = EMBEDDER.encode([query]).astype("float32")
    # 2) Search FAISS. FIX: when k exceeds the index size, FAISS pads the
    #    id array with -1, which would otherwise index DOCS[-1] and return
    #    the last document spuriously — filter those out.
    _dists, ids = INDEX.search(q_emb, k)
    hits = [DOCS[i] for i in ids[0] if i != -1]
    # 3) Concatenate the retrieved texts.
    return "\n\n---\n\n".join(hits)
81
+
82
 
83
# FIX: the original read env var "SERPAPY_KEY" (typo) into SERPAPI_KEY.
# Read the correctly-spelled name first, falling back to the misspelled
# one so any existing deployment that set SERPAPY_KEY keeps working.
SERPAPI_KEY = os.getenv("SERPAPI_KEY") or os.getenv("SERPAPY_KEY")
# ---- web_search tool