Nottybro committed on
Commit
9d77494
·
verified ·
1 Parent(s): 45cb177

fix: L0 queries check docs first (similarity > 0.75) before falling back to model knowledge

Browse files
Files changed (1) hide show
  1. acra.py +41 -26
acra.py CHANGED
@@ -5,31 +5,21 @@ from db import supabase
5
  from classifier_inference import classify_query
6
  from typing import List
7
 
8
- client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
9
- EMBED_MODEL = "gemini-embedding-001" # replaces shut-down text-embedding-004
10
  GEN_MODEL = "gemma-3-27b-it"
11
- DEPTH = {0: 0, 1: 3, 2: 6, 3: 10}
12
 
13
  def embed_texts(texts):
14
  result = client.models.embed_content(
15
- model=EMBED_MODEL,
16
- contents=texts,
17
- config=types.EmbedContentConfig(
18
- task_type="RETRIEVAL_DOCUMENT",
19
- output_dimensionality=768 # keeps existing Supabase vector(768) schema
20
- )
21
- )
22
  return [e.values for e in result.embeddings]
23
 
24
  def embed_query(q):
25
  result = client.models.embed_content(
26
- model=EMBED_MODEL,
27
- contents=[q],
28
- config=types.EmbedContentConfig(
29
- task_type="RETRIEVAL_QUERY",
30
- output_dimensionality=768
31
- )
32
- )
33
  return result.embeddings[0].values
34
 
35
  def adaptive_chunk(text, max_tok=512):
@@ -81,7 +71,7 @@ def vsearch(query, namespace, user_id, k):
81
  }).execute().data or [])
82
 
83
  PROMPTS = {
84
- 0: "Answer from your knowledge:\n\n{q}",
85
  1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
86
  2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
87
  3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
@@ -104,10 +94,30 @@ async def ingest_pipeline(texts, metadata, namespace, user_id):
104
  return len(chunks)
105
 
106
  async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
107
- cls = classify_query(query); level = cls["level"]; k = DEPTH[level]
 
 
 
 
 
108
  if level == 0:
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  r = client.models.generate_content(model=GEN_MODEL, contents=PROMPTS[0].format(q=query))
110
  return {"answer": r.text.strip(), "sources": [], "complexity": cls, "retrieval_source": "model_knowledge"}
 
 
111
  hits = []
112
  if level == 3:
113
  seen = set()
@@ -116,24 +126,29 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
116
  if h["id"] not in seen: seen.add(h["id"]); hits.append(h)
117
  else:
118
  hits = vsearch(query, namespace, user_id, k)
 
119
  web_hits, retrieval_source = [], "local"
120
  if use_web or not hits:
121
  web_hits = web_search(query, max_results=k)
122
  if not hits and not web_hits:
123
  return {"answer": "Nothing found locally or on the web.", "sources": [], "complexity": cls, "retrieval_source": "none"}
124
  retrieval_source = "web" if not hits else "local_and_web"
 
125
  all_chunks, all_sources = [], []
126
  if hits:
127
  lc = [h["content"] for h in hits]
128
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
129
- all_chunks += lc[:k]
130
- all_sources += [{"content": h["content"][:200], "metadata": h.get("metadata",{}), "score": h.get("similarity",0), "source": "local"} for h in hits[:len(lc)]]
 
131
  if web_hits:
132
- all_chunks += [f"{h['title']}: {h['snippet']}" for h in web_hits]
133
- all_sources += [{"content": h["snippet"][:200], "metadata": {"title": h["title"], "url": h["url"]}, "score": 1.0, "source": "web"} for h in web_hits]
134
- ctx = "\n\n---\n\n".join(all_chunks)
135
- prompt = (WEB_PROMPTS if retrieval_source=="web" else PROMPTS).get(level, PROMPTS[level])
136
- r = client.models.generate_content(model=GEN_MODEL, contents=prompt.format(ctx=ctx, q=query))
 
 
137
  return {"answer": r.text.strip(), "sources": all_sources, "complexity": cls, "retrieval_source": retrieval_source}
138
 
139
  async def run_acra_pipeline(mode, **kw):
 
5
  from classifier_inference import classify_query
6
  from typing import List
7
 
8
+ client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
9
+ EMBED_MODEL = "gemini-embedding-001"
10
  GEN_MODEL = "gemma-3-27b-it"
11
+ DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
  def embed_texts(texts):
14
  result = client.models.embed_content(
15
+ model=EMBED_MODEL, contents=texts,
16
+ config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
 
 
 
 
 
17
  return [e.values for e in result.embeddings]
18
 
19
  def embed_query(q):
20
  result = client.models.embed_content(
21
+ model=EMBED_MODEL, contents=[q],
22
+ config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY", output_dimensionality=768))
 
 
 
 
 
23
  return result.embeddings[0].values
24
 
25
  def adaptive_chunk(text, max_tok=512):
 
71
  }).execute().data or [])
72
 
73
  PROMPTS = {
74
+ 0: "Answer this from your knowledge:\n\n{q}",
75
  1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
76
  2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
77
  3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
 
94
  return len(chunks)
95
 
96
  async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
97
+ cls = classify_query(query)
98
+ level = cls["level"]
99
+ k = DEPTH[level]
100
+ model = client
101
+
102
+ # ── L0: try docs first (similarity > 0.75), fall back to model knowledge
103
  if level == 0:
104
+ l0_hits = vsearch(query, namespace, user_id, 3)
105
+ strong_hits = [h for h in l0_hits if h.get("similarity", 0) > 0.75]
106
+ if strong_hits:
107
+ ctx = "\n\n---\n\n".join(h["content"] for h in strong_hits)
108
+ r = client.models.generate_content(model=GEN_MODEL,
109
+ contents=f"Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {query}\nAnswer:")
110
+ return {
111
+ "answer": r.text.strip(),
112
+ "sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
113
+ "score": h.get("similarity", 0), "source": "local"} for h in strong_hits],
114
+ "complexity": cls,
115
+ "retrieval_source": "local"
116
+ }
117
  r = client.models.generate_content(model=GEN_MODEL, contents=PROMPTS[0].format(q=query))
118
  return {"answer": r.text.strip(), "sources": [], "complexity": cls, "retrieval_source": "model_knowledge"}
119
+
120
+ # ── L1-L3: standard retrieval
121
  hits = []
122
  if level == 3:
123
  seen = set()
 
126
  if h["id"] not in seen: seen.add(h["id"]); hits.append(h)
127
  else:
128
  hits = vsearch(query, namespace, user_id, k)
129
+
130
  web_hits, retrieval_source = [], "local"
131
  if use_web or not hits:
132
  web_hits = web_search(query, max_results=k)
133
  if not hits and not web_hits:
134
  return {"answer": "Nothing found locally or on the web.", "sources": [], "complexity": cls, "retrieval_source": "none"}
135
  retrieval_source = "web" if not hits else "local_and_web"
136
+
137
  all_chunks, all_sources = [], []
138
  if hits:
139
  lc = [h["content"] for h in hits]
140
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
141
+ all_chunks += lc[:k]
142
+ all_sources += [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
143
+ "score": h.get("similarity", 0), "source": "local"} for h in hits[:len(lc)]]
144
  if web_hits:
145
+ all_chunks += [f"{h['title']}: {h['snippet']}" for h in web_hits]
146
+ all_sources += [{"content": h["snippet"][:200], "metadata": {"title": h["title"], "url": h["url"]},
147
+ "score": 1.0, "source": "web"} for h in web_hits]
148
+
149
+ ctx = "\n\n---\n\n".join(all_chunks)
150
+ prompt = (WEB_PROMPTS if retrieval_source == "web" else PROMPTS).get(level, PROMPTS[level])
151
+ r = client.models.generate_content(model=GEN_MODEL, contents=prompt.format(ctx=ctx, q=query))
152
  return {"answer": r.text.strip(), "sources": all_sources, "complexity": cls, "retrieval_source": retrieval_source}
153
 
154
  async def run_acra_pipeline(mode, **kw):