Nottybro commited on
Commit
ea894c2
·
verified ·
1 Parent(s): aa7d131

fix: acra.py — switch to google-genai + gemini-embedding-001

Browse files
Files changed (1) hide show
  1. acra.py +82 -64
acra.py CHANGED
@@ -1,21 +1,36 @@
1
- import os
2
- import google.generativeai as genai
 
3
  from db import supabase
4
  from classifier_inference import classify_query
5
  from typing import List
6
 
7
- genai.configure(api_key=os.environ["GEMINI_API_KEY"])
8
- EMBED_MODEL = "models/text-embedding-004"
9
  GEN_MODEL = "gemma-3-27b-it"
10
  DEPTH = {0: 0, 1: 3, 2: 6, 3: 10}
11
 
12
  def embed_texts(texts):
13
- return [genai.embed_content(model=EMBED_MODEL, content=t,
14
- task_type="retrieval_document")["embedding"] for t in texts]
 
 
 
 
 
 
 
15
 
16
  def embed_query(q):
17
- return genai.embed_content(model=EMBED_MODEL, content=q,
18
- task_type="retrieval_query")["embedding"]
 
 
 
 
 
 
 
19
 
20
  def adaptive_chunk(text, max_tok=512):
21
  paras = [p.strip() for p in text.split("\n\n") if p.strip()]
@@ -29,96 +44,99 @@ def adaptive_chunk(text, max_tok=512):
29
  if cur: chunks.append(cur)
30
  return chunks or [text]
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def decompose(query):
33
- m = genai.GenerativeModel(GEN_MODEL)
34
- r = m.generate_content(
35
- f"Decompose into 2-4 simpler sub-queries. "
36
- f"Return numbered list only.\n\nQuery: {query}")
37
- lines = [l.strip().lstrip("1234567890.). ")
38
- for l in r.text.strip().split("\n") if l.strip()]
39
  return lines[:4] or [query]
40
 
41
  def compress(query, chunks):
42
- m = genai.GenerativeModel(GEN_MODEL)
43
  out = []
44
  for c in chunks:
45
- r = m.generate_content(
46
- f"Extract only sentences relevant to the query. "
47
- f"Return empty if none.\n\nQuery: {query}\nChunk: {c}")
48
  if r.text.strip(): out.append(r.text.strip())
49
  return out
50
 
51
  def vsearch(query, namespace, user_id, k):
52
  return (supabase.rpc("match_documents", {
53
- "query_embedding": embed_query(query),
54
- "match_count": k,
55
  "filter_namespace": namespace,
56
- "filter_user_id": user_id
57
  }).execute().data or [])
58
 
59
  PROMPTS = {
60
- 1: "Answer using ONLY the context below. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
61
- 2: "Synthesize the context to answer. Think step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
62
- 3: "Use chain-of-thought to answer this complex question.\nAddress each aspect. Note any gaps.\n\nContext:\n{ctx}\n\nQuestion: {q}\nReasoning and answer:",
 
 
 
 
 
 
63
  }
64
 
65
  async def ingest_pipeline(texts, metadata, namespace, user_id):
66
  chunks, meta = [], []
67
  for i, t in enumerate(texts):
68
  for j, c in enumerate(adaptive_chunk(t)):
69
- chunks.append(c)
70
- meta.append({**metadata[i], "source_index": i, "chunk_index": j})
71
- rows = [{"content": c, "embedding": e, "metadata": m,
72
- "namespace": namespace, "user_id": user_id}
73
  for c, e, m in zip(chunks, embed_texts(chunks), meta)]
74
  for i in range(0, len(rows), 50):
75
  supabase.table("documents").insert(rows[i:i+50]).execute()
76
  return len(chunks)
77
 
78
- async def query_pipeline(query, namespace, top_k, rerank, user_id):
79
- cls = classify_query(query)
80
- level = cls["level"]
81
- k = DEPTH[level]
82
- model = genai.GenerativeModel(GEN_MODEL)
83
-
84
  if level == 0:
85
- r = model.generate_content(
86
- f"Answer concisely from your knowledge:\n\n{query}")
87
- return {"answer": r.text.strip(), "sources": [], "complexity": cls}
88
-
89
  hits = []
90
  if level == 3:
91
  seen = set()
92
  for sq in decompose(query):
93
  for h in vsearch(sq, namespace, user_id, 4):
94
- if h["id"] not in seen:
95
- seen.add(h["id"]); hits.append(h)
96
  else:
97
  hits = vsearch(query, namespace, user_id, k)
98
-
99
- if not hits:
100
- return {"answer": "No relevant documents found. Ingest some first.",
101
- "sources": [], "complexity": cls}
102
-
103
- chunks = [h["content"] for h in hits]
104
- if rerank and level >= 2:
105
- chunks = [c for c in compress(query, chunks) if c.strip()]
106
-
107
- ctx = "\n\n---\n\n".join(chunks[:k])
108
- r = model.generate_content(PROMPTS[level].format(ctx=ctx, q=query))
109
-
110
- return {
111
- "answer": r.text.strip(),
112
- "sources": [{"content": h["content"][:200],
113
- "metadata": h.get("metadata", {}),
114
- "score": h.get("similarity", 0)}
115
- for h in hits[:len(chunks)]],
116
- "complexity": cls
117
- }
118
 
119
  async def run_acra_pipeline(mode, **kw):
120
  if mode == "ingest":
121
- return await ingest_pipeline(kw["texts"], kw["metadata"],
122
- kw["namespace"], kw["user_id"])
123
- return await query_pipeline(kw["query"], kw["namespace"],
124
- kw["top_k"], kw["rerank"], kw["user_id"])
 
1
+ import os, httpx
2
+ from google import genai
3
+ from google.genai import types
4
  from db import supabase
5
  from classifier_inference import classify_query
6
  from typing import List
7
 
8
+ client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
9
+ EMBED_MODEL = "gemini-embedding-001" # replaces shut-down text-embedding-004
10
  GEN_MODEL = "gemma-3-27b-it"
11
  DEPTH = {0: 0, 1: 3, 2: 6, 3: 10}
12
 
13
  def embed_texts(texts):
14
+ result = client.models.embed_content(
15
+ model=EMBED_MODEL,
16
+ contents=texts,
17
+ config=types.EmbedContentConfig(
18
+ task_type="RETRIEVAL_DOCUMENT",
19
+ output_dimensionality=768 # keeps existing Supabase vector(768) schema
20
+ )
21
+ )
22
+ return [e.values for e in result.embeddings]
23
 
24
  def embed_query(q):
25
+ result = client.models.embed_content(
26
+ model=EMBED_MODEL,
27
+ contents=[q],
28
+ config=types.EmbedContentConfig(
29
+ task_type="RETRIEVAL_QUERY",
30
+ output_dimensionality=768
31
+ )
32
+ )
33
+ return result.embeddings[0].values
34
 
35
  def adaptive_chunk(text, max_tok=512):
36
  paras = [p.strip() for p in text.split("\n\n") if p.strip()]
 
44
  if cur: chunks.append(cur)
45
  return chunks or [text]
46
 
47
+ def web_search(query, max_results=5):
48
+ try:
49
+ r = httpx.get("https://api.duckduckgo.com/",
50
+ params={"q": query, "format": "json", "no_html": "1", "skip_disambig": "1"},
51
+ headers={"User-Agent": "ACRA/1.0"}, timeout=10.0)
52
+ data, results = r.json(), []
53
+ if data.get("AbstractText"):
54
+ results.append({"title": data.get("Heading","Web"), "snippet": data["AbstractText"], "url": data.get("AbstractURL","")})
55
+ for t in data.get("RelatedTopics", [])[:max_results]:
56
+ if isinstance(t, dict) and t.get("Text"):
57
+ results.append({"title": t.get("Name","Web"), "snippet": t["Text"], "url": t.get("FirstURL","")})
58
+ return results[:max_results]
59
+ except: return []
60
+
61
  def decompose(query):
62
+ r = client.models.generate_content(model=GEN_MODEL,
63
+ contents=f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
64
+ lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
 
 
 
65
  return lines[:4] or [query]
66
 
67
  def compress(query, chunks):
 
68
  out = []
69
  for c in chunks:
70
+ r = client.models.generate_content(model=GEN_MODEL,
71
+ contents=f"Extract only sentences relevant to the query. Return empty if none.\n\nQuery: {query}\nChunk: {c}")
 
72
  if r.text.strip(): out.append(r.text.strip())
73
  return out
74
 
75
  def vsearch(query, namespace, user_id, k):
76
  return (supabase.rpc("match_documents", {
77
+ "query_embedding": embed_query(query),
78
+ "match_count": k,
79
  "filter_namespace": namespace,
80
+ "filter_user_id": user_id,
81
  }).execute().data or [])
82
 
83
  PROMPTS = {
84
+ 0: "Answer from your knowledge:\n\n{q}",
85
+ 1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
86
+ 2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
87
+ 3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
88
+ }
89
+ WEB_PROMPTS = {
90
+ 1: "Answer using these web results:\n\n{ctx}\n\nQuestion: {q}\nAnswer:",
91
+ 2: "Synthesize these web results:\n\n{ctx}\n\nQuestion: {q}\nAnswer:",
92
+ 3: "Reason through this using web results:\n\n{ctx}\n\nQuestion: {q}\nAnswer:",
93
  }
94
 
95
  async def ingest_pipeline(texts, metadata, namespace, user_id):
96
  chunks, meta = [], []
97
  for i, t in enumerate(texts):
98
  for j, c in enumerate(adaptive_chunk(t)):
99
+ chunks.append(c); meta.append({**metadata[i], "source_index": i, "chunk_index": j})
100
+ rows = [{"content": c, "embedding": e, "metadata": m, "namespace": namespace, "user_id": user_id}
 
 
101
  for c, e, m in zip(chunks, embed_texts(chunks), meta)]
102
  for i in range(0, len(rows), 50):
103
  supabase.table("documents").insert(rows[i:i+50]).execute()
104
  return len(chunks)
105
 
106
+ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
107
+ cls = classify_query(query); level = cls["level"]; k = DEPTH[level]
 
 
 
 
108
  if level == 0:
109
+ r = client.models.generate_content(model=GEN_MODEL, contents=PROMPTS[0].format(q=query))
110
+ return {"answer": r.text.strip(), "sources": [], "complexity": cls, "retrieval_source": "model_knowledge"}
 
 
111
  hits = []
112
  if level == 3:
113
  seen = set()
114
  for sq in decompose(query):
115
  for h in vsearch(sq, namespace, user_id, 4):
116
+ if h["id"] not in seen: seen.add(h["id"]); hits.append(h)
 
117
  else:
118
  hits = vsearch(query, namespace, user_id, k)
119
+ web_hits, retrieval_source = [], "local"
120
+ if use_web or not hits:
121
+ web_hits = web_search(query, max_results=k)
122
+ if not hits and not web_hits:
123
+ return {"answer": "Nothing found locally or on the web.", "sources": [], "complexity": cls, "retrieval_source": "none"}
124
+ retrieval_source = "web" if not hits else "local_and_web"
125
+ all_chunks, all_sources = [], []
126
+ if hits:
127
+ lc = [h["content"] for h in hits]
128
+ if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
129
+ all_chunks += lc[:k]
130
+ all_sources += [{"content": h["content"][:200], "metadata": h.get("metadata",{}), "score": h.get("similarity",0), "source": "local"} for h in hits[:len(lc)]]
131
+ if web_hits:
132
+ all_chunks += [f"{h['title']}: {h['snippet']}" for h in web_hits]
133
+ all_sources += [{"content": h["snippet"][:200], "metadata": {"title": h["title"], "url": h["url"]}, "score": 1.0, "source": "web"} for h in web_hits]
134
+ ctx = "\n\n---\n\n".join(all_chunks)
135
+ prompt = (WEB_PROMPTS if retrieval_source=="web" else PROMPTS).get(level, PROMPTS[level])
136
+ r = client.models.generate_content(model=GEN_MODEL, contents=prompt.format(ctx=ctx, q=query))
137
+ return {"answer": r.text.strip(), "sources": all_sources, "complexity": cls, "retrieval_source": retrieval_source}
 
138
 
139
  async def run_acra_pipeline(mode, **kw):
140
  if mode == "ingest":
141
+ return await ingest_pipeline(kw["texts"], kw["metadata"], kw["namespace"], kw["user_id"])
142
+ return await query_pipeline(kw["query"], kw["namespace"], kw["top_k"], kw["rerank"], kw["user_id"], use_web=kw.get("use_web", False))