Nottybro committed on
Commit
92d9d21
·
verified ·
1 Parent(s): 9d77494

fix: L0 always retrieves then lets Gemma decide context vs knowledge

Browse files
Files changed (1) hide show
  1. acra.py +42 -33
acra.py CHANGED
@@ -5,22 +5,20 @@ from db import supabase
5
  from classifier_inference import classify_query
6
  from typing import List
7
 
8
- client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
9
  EMBED_MODEL = "gemini-embedding-001"
10
  GEN_MODEL = "gemma-3-27b-it"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
  def embed_texts(texts):
14
- result = client.models.embed_content(
15
- model=EMBED_MODEL, contents=texts,
16
  config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
17
- return [e.values for e in result.embeddings]
18
 
19
  def embed_query(q):
20
- result = client.models.embed_content(
21
- model=EMBED_MODEL, contents=[q],
22
  config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY", output_dimensionality=768))
23
- return result.embeddings[0].values
24
 
25
  def adaptive_chunk(text, max_tok=512):
26
  paras = [p.strip() for p in text.split("\n\n") if p.strip()]
@@ -39,13 +37,13 @@ def web_search(query, max_results=5):
39
  r = httpx.get("https://api.duckduckgo.com/",
40
  params={"q": query, "format": "json", "no_html": "1", "skip_disambig": "1"},
41
  headers={"User-Agent": "ACRA/1.0"}, timeout=10.0)
42
- data, results = r.json(), []
43
  if data.get("AbstractText"):
44
- results.append({"title": data.get("Heading","Web"), "snippet": data["AbstractText"], "url": data.get("AbstractURL","")})
45
- for t in data.get("RelatedTopics", [])[:max_results]:
46
- if isinstance(t, dict) and t.get("Text"):
47
- results.append({"title": t.get("Name","Web"), "snippet": t["Text"], "url": t.get("FirstURL","")})
48
- return results[:max_results]
49
  except: return []
50
 
51
  def decompose(query):
@@ -71,7 +69,6 @@ def vsearch(query, namespace, user_id, k):
71
  }).execute().data or [])
72
 
73
  PROMPTS = {
74
- 0: "Answer this from your knowledge:\n\n{q}",
75
  1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
76
  2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
77
  3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
@@ -86,7 +83,8 @@ async def ingest_pipeline(texts, metadata, namespace, user_id):
86
  chunks, meta = [], []
87
  for i, t in enumerate(texts):
88
  for j, c in enumerate(adaptive_chunk(t)):
89
- chunks.append(c); meta.append({**metadata[i], "source_index": i, "chunk_index": j})
 
90
  rows = [{"content": c, "embedding": e, "metadata": m, "namespace": namespace, "user_id": user_id}
91
  for c, e, m in zip(chunks, embed_texts(chunks), meta)]
92
  for i in range(0, len(rows), 50):
@@ -97,27 +95,35 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
97
  cls = classify_query(query)
98
  level = cls["level"]
99
  k = DEPTH[level]
100
- model = client
101
 
102
- # ── L0: try docs first (similarity > 0.75), fall back to model knowledge
 
 
103
  if level == 0:
104
- l0_hits = vsearch(query, namespace, user_id, 3)
105
- strong_hits = [h for h in l0_hits if h.get("similarity", 0) > 0.75]
106
- if strong_hits:
107
- ctx = "\n\n---\n\n".join(h["content"] for h in strong_hits)
108
  r = client.models.generate_content(model=GEN_MODEL,
109
- contents=f"Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {query}\nAnswer:")
 
 
 
 
 
110
  return {
111
  "answer": r.text.strip(),
112
- "sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
113
- "score": h.get("similarity", 0), "source": "local"} for h in strong_hits],
 
114
  "complexity": cls,
115
- "retrieval_source": "local"
116
  }
117
- r = client.models.generate_content(model=GEN_MODEL, contents=PROMPTS[0].format(q=query))
 
 
118
  return {"answer": r.text.strip(), "sources": [], "complexity": cls, "retrieval_source": "model_knowledge"}
119
 
120
- # ── L1-L3: standard retrieval
121
  hits = []
122
  if level == 3:
123
  seen = set()
@@ -131,7 +137,8 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
131
  if use_web or not hits:
132
  web_hits = web_search(query, max_results=k)
133
  if not hits and not web_hits:
134
- return {"answer": "Nothing found locally or on the web.", "sources": [], "complexity": cls, "retrieval_source": "none"}
 
135
  retrieval_source = "web" if not hits else "local_and_web"
136
 
137
  all_chunks, all_sources = [], []
@@ -139,19 +146,21 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
139
  lc = [h["content"] for h in hits]
140
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
141
  all_chunks += lc[:k]
142
- all_sources += [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
143
- "score": h.get("similarity", 0), "source": "local"} for h in hits[:len(lc)]]
144
  if web_hits:
145
  all_chunks += [f"{h['title']}: {h['snippet']}" for h in web_hits]
146
  all_sources += [{"content": h["snippet"][:200], "metadata": {"title": h["title"], "url": h["url"]},
147
  "score": 1.0, "source": "web"} for h in web_hits]
148
 
149
  ctx = "\n\n---\n\n".join(all_chunks)
150
- prompt = (WEB_PROMPTS if retrieval_source == "web" else PROMPTS).get(level, PROMPTS[level])
151
  r = client.models.generate_content(model=GEN_MODEL, contents=prompt.format(ctx=ctx, q=query))
152
- return {"answer": r.text.strip(), "sources": all_sources, "complexity": cls, "retrieval_source": retrieval_source}
 
153
 
154
  async def run_acra_pipeline(mode, **kw):
155
  if mode == "ingest":
156
  return await ingest_pipeline(kw["texts"], kw["metadata"], kw["namespace"], kw["user_id"])
157
- return await query_pipeline(kw["query"], kw["namespace"], kw["top_k"], kw["rerank"], kw["user_id"], use_web=kw.get("use_web", False))
 
 
5
  from classifier_inference import classify_query
6
  from typing import List
7
 
8
+ client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
9
  EMBED_MODEL = "gemini-embedding-001"
10
  GEN_MODEL = "gemma-3-27b-it"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
  def embed_texts(texts):
14
+ r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
 
15
  config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
16
+ return [e.values for e in r.embeddings]
17
 
18
  def embed_query(q):
19
+ r = client.models.embed_content(model=EMBED_MODEL, contents=[q],
 
20
  config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY", output_dimensionality=768))
21
+ return r.embeddings[0].values
22
 
23
  def adaptive_chunk(text, max_tok=512):
24
  paras = [p.strip() for p in text.split("\n\n") if p.strip()]
 
37
  r = httpx.get("https://api.duckduckgo.com/",
38
  params={"q": query, "format": "json", "no_html": "1", "skip_disambig": "1"},
39
  headers={"User-Agent": "ACRA/1.0"}, timeout=10.0)
40
+ data, out = r.json(), []
41
  if data.get("AbstractText"):
42
+ out.append({"title": data.get("Heading","Web"), "snippet": data["AbstractText"], "url": data.get("AbstractURL","")})
43
+ for t in data.get("RelatedTopics",[])[:max_results]:
44
+ if isinstance(t,dict) and t.get("Text"):
45
+ out.append({"title": t.get("Name","Web"), "snippet": t["Text"], "url": t.get("FirstURL","")})
46
+ return out[:max_results]
47
  except: return []
48
 
49
  def decompose(query):
 
69
  }).execute().data or [])
70
 
71
  PROMPTS = {
 
72
  1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
73
  2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
74
  3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
 
83
  chunks, meta = [], []
84
  for i, t in enumerate(texts):
85
  for j, c in enumerate(adaptive_chunk(t)):
86
+ chunks.append(c)
87
+ meta.append({**metadata[i], "source_index": i, "chunk_index": j})
88
  rows = [{"content": c, "embedding": e, "metadata": m, "namespace": namespace, "user_id": user_id}
89
  for c, e, m in zip(chunks, embed_texts(chunks), meta)]
90
  for i in range(0, len(rows), 50):
 
95
  cls = classify_query(query)
96
  level = cls["level"]
97
  k = DEPTH[level]
 
98
 
99
+ # L0: always retrieve first — give Gemma the context and let it decide
100
+ # whether to use it or answer from its own knowledge. This prevents
101
+ # hallucination when the answer exists in the user docs.
102
  if level == 0:
103
+ hits = vsearch(query, namespace, user_id, 2)
104
+ if hits:
105
+ ctx = "\n\n---\n\n".join(h["content"] for h in hits)
 
106
  r = client.models.generate_content(model=GEN_MODEL,
107
+ contents=(
108
+ f"Use the context below if it contains a relevant answer to the question. "
109
+ f"If the context is not relevant, answer from your own knowledge instead.\n\n"
110
+ f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"
111
+ ))
112
+ top_score = hits[0].get("similarity", 0)
113
  return {
114
  "answer": r.text.strip(),
115
+ "sources": [{"content": h["content"][:200], "metadata": h.get("metadata",{}),
116
+ "score": h.get("similarity",0), "source": "local"}
117
+ for h in hits if h.get("similarity",0) > 0.5],
118
  "complexity": cls,
119
+ "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
120
  }
121
+ # No docs at all — answer from model knowledge
122
+ r = client.models.generate_content(model=GEN_MODEL,
123
+ contents=f"Answer this from your knowledge:\n\n{query}")
124
  return {"answer": r.text.strip(), "sources": [], "complexity": cls, "retrieval_source": "model_knowledge"}
125
 
126
+ # L1–L3: standard retrieval
127
  hits = []
128
  if level == 3:
129
  seen = set()
 
137
  if use_web or not hits:
138
  web_hits = web_search(query, max_results=k)
139
  if not hits and not web_hits:
140
+ return {"answer": "Nothing found locally or on the web.", "sources": [],
141
+ "complexity": cls, "retrieval_source": "none"}
142
  retrieval_source = "web" if not hits else "local_and_web"
143
 
144
  all_chunks, all_sources = [], []
 
146
  lc = [h["content"] for h in hits]
147
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
148
  all_chunks += lc[:k]
149
+ all_sources += [{"content": h["content"][:200], "metadata": h.get("metadata",{}),
150
+ "score": h.get("similarity",0), "source": "local"} for h in hits[:len(lc)]]
151
  if web_hits:
152
  all_chunks += [f"{h['title']}: {h['snippet']}" for h in web_hits]
153
  all_sources += [{"content": h["snippet"][:200], "metadata": {"title": h["title"], "url": h["url"]},
154
  "score": 1.0, "source": "web"} for h in web_hits]
155
 
156
  ctx = "\n\n---\n\n".join(all_chunks)
157
+ prompt = (WEB_PROMPTS if retrieval_source == "web" else PROMPTS).get(level, PROMPTS[1])
158
  r = client.models.generate_content(model=GEN_MODEL, contents=prompt.format(ctx=ctx, q=query))
159
+ return {"answer": r.text.strip(), "sources": all_sources,
160
+ "complexity": cls, "retrieval_source": retrieval_source}
161
 
162
  async def run_acra_pipeline(mode, **kw):
163
  if mode == "ingest":
164
  return await ingest_pipeline(kw["texts"], kw["metadata"], kw["namespace"], kw["user_id"])
165
+ return await query_pipeline(kw["query"], kw["namespace"], kw["top_k"],
166
+ kw["rerank"], kw["user_id"], use_web=kw.get("use_web", False))