fix: 503 retry + token tracking + cost calc per response
Browse files
acra.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import os, httpx
|
| 2 |
from google import genai
|
| 3 |
from google.genai import types
|
| 4 |
from db import supabase
|
|
@@ -10,6 +10,44 @@ EMBED_MODEL = "gemini-embedding-001"
|
|
| 10 |
GEN_MODEL = "gemini-3.1-flash-lite-preview"
|
| 11 |
DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
|
| 12 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
def embed_texts(texts):
|
| 14 |
r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
|
| 15 |
config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
|
|
@@ -33,10 +71,9 @@ def adaptive_chunk(text, max_tok=512):
|
|
| 33 |
return chunks or [text]
|
| 34 |
|
| 35 |
def web_search(query: str, max_results: int = 5) -> List[dict]:
|
| 36 |
-
"""Web search via Jina AI β requires JINA_API_KEY secret in Space."""
|
| 37 |
try:
|
| 38 |
import urllib.parse
|
| 39 |
-
encoded
|
| 40 |
jina_key = os.environ.get("JINA_API_KEY", "")
|
| 41 |
r = httpx.get(
|
| 42 |
f"https://s.jina.ai/?q={encoded}",
|
|
@@ -46,7 +83,7 @@ def web_search(query: str, max_results: int = 5) -> List[dict]:
|
|
| 46 |
"X-Retain-Images": "none",
|
| 47 |
"X-Engine": "direct",
|
| 48 |
},
|
| 49 |
-
timeout=
|
| 50 |
follow_redirects=True
|
| 51 |
)
|
| 52 |
if r.status_code != 200:
|
|
@@ -69,20 +106,18 @@ def web_search(query: str, max_results: int = 5) -> List[dict]:
|
|
| 69 |
return []
|
| 70 |
|
| 71 |
def decompose(query):
|
| 72 |
-
r =
|
| 73 |
-
contents=f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
|
| 74 |
lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
|
| 75 |
return lines[:4] or [query]
|
| 76 |
|
| 77 |
def compress(query, chunks):
|
| 78 |
if not chunks: return []
|
| 79 |
numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
|
| 80 |
-
r =
|
| 81 |
f"You have {len(chunks)} text chunks and a query.\n"
|
| 82 |
f"For each chunk extract ONLY sentences relevant to the query.\n"
|
| 83 |
f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
|
| 84 |
-
f"Query: {query}\n\nChunks:\n{numbered}")
|
| 85 |
-
import re
|
| 86 |
out = []
|
| 87 |
for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
|
| 88 |
t = m.group(2).strip()
|
|
@@ -123,20 +158,22 @@ async def ingest_pipeline(texts, metadata, namespace, user_id):
|
|
| 123 |
return len(chunks)
|
| 124 |
|
| 125 |
async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
|
|
|
|
|
|
|
|
|
|
| 126 |
cls = classify_query(query)
|
| 127 |
level = cls["level"]
|
| 128 |
k = DEPTH[level]
|
| 129 |
|
| 130 |
-
# use_web=True: skip ALL local retrieval, pure Jina web search
|
| 131 |
if use_web:
|
| 132 |
hits = web_search(query, max_results=6)
|
| 133 |
if not hits:
|
| 134 |
return {"answer": "No web results found.", "sources": [],
|
| 135 |
-
"complexity": cls, "retrieval_source": "none"
|
|
|
|
| 136 |
ctx = "\n\n---\n\n".join(
|
| 137 |
f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
|
| 138 |
-
r =
|
| 139 |
-
contents=WEB_PROMPT.format(ctx=ctx, q=query))
|
| 140 |
return {
|
| 141 |
"answer": r.text.strip(),
|
| 142 |
"sources": [{"content": h["snippet"][:200],
|
|
@@ -144,17 +181,17 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
|
|
| 144 |
"score": 1.0, "source": "web"} for h in hits],
|
| 145 |
"complexity": cls,
|
| 146 |
"retrieval_source": "web",
|
|
|
|
| 147 |
}
|
| 148 |
|
| 149 |
-
# L0: check docs first, fall back to model knowledge
|
| 150 |
if level == 0:
|
| 151 |
doc_hits = vsearch(query, namespace, user_id, 2)
|
| 152 |
if doc_hits:
|
| 153 |
ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
|
| 154 |
-
r =
|
| 155 |
f"Use the context if it contains a relevant answer. "
|
| 156 |
f"Otherwise answer from your own knowledge.\n\n"
|
| 157 |
-
f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
|
| 158 |
top_score = doc_hits[0].get("similarity", 0)
|
| 159 |
return {
|
| 160 |
"answer": r.text.strip(),
|
|
@@ -165,13 +202,13 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
|
|
| 165 |
for h in doc_hits if h.get("similarity", 0) > 0.5],
|
| 166 |
"complexity": cls,
|
| 167 |
"retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
|
|
|
|
| 168 |
}
|
| 169 |
-
r =
|
| 170 |
-
contents=f"Answer from your knowledge:\n\n{query}")
|
| 171 |
return {"answer": r.text.strip(), "sources": [],
|
| 172 |
-
"complexity": cls, "retrieval_source": "model_knowledge"
|
|
|
|
| 173 |
|
| 174 |
-
# L1-L3: local vector retrieval
|
| 175 |
hits = []
|
| 176 |
if level == 3:
|
| 177 |
seen = set()
|
|
@@ -185,28 +222,29 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
|
|
| 185 |
web_hits = web_search(query, max_results=k)
|
| 186 |
if not web_hits:
|
| 187 |
return {"answer": "Nothing found locally or on the web.",
|
| 188 |
-
"sources": [], "complexity": cls, "retrieval_source": "none"
|
|
|
|
| 189 |
ctx = "\n\n---\n\n".join(
|
| 190 |
f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
|
| 191 |
-
r =
|
| 192 |
-
contents=WEB_PROMPT.format(ctx=ctx, q=query))
|
| 193 |
return {"answer": r.text.strip(),
|
| 194 |
"sources": [{"content": h["snippet"][:200],
|
| 195 |
"metadata": {"title": h["title"], "url": h["url"]},
|
| 196 |
"score": 1.0, "source": "web"} for h in web_hits],
|
| 197 |
-
"complexity": cls, "retrieval_source": "web"
|
|
|
|
| 198 |
|
| 199 |
lc = [h["content"] for h in hits]
|
| 200 |
if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
|
| 201 |
ctx = "\n\n---\n\n".join(lc[:k])
|
| 202 |
-
r =
|
| 203 |
-
contents=PROMPTS[level].format(ctx=ctx, q=query))
|
| 204 |
return {
|
| 205 |
"answer": r.text.strip(),
|
| 206 |
"sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
|
| 207 |
"score": h.get("similarity", 0), "source": "local"}
|
| 208 |
for h in hits[:len(lc)]],
|
| 209 |
"complexity": cls, "retrieval_source": "local",
|
|
|
|
| 210 |
}
|
| 211 |
|
| 212 |
async def run_acra_pipeline(mode, **kw):
|
|
|
|
| 1 |
+
import os, httpx, time, re
|
| 2 |
from google import genai
|
| 3 |
from google.genai import types
|
| 4 |
from db import supabase
|
|
|
|
| 10 |
GEN_MODEL = "gemini-3.1-flash-lite-preview"
|
| 11 |
DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
|
| 12 |
|
| 13 |
+
# ββ Pricing ββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 14 |
+
PRICE_IN = 0.075 / 1_000_000
|
| 15 |
+
PRICE_OUT = 0.30 / 1_000_000
|
| 16 |
+
GPT4O_IN = 2.50 / 1_000_000
|
| 17 |
+
GPT4O_OUT = 10.00 / 1_000_000
|
| 18 |
+
|
| 19 |
+
_total_input_tokens = 0
|
| 20 |
+
_total_output_tokens = 0
|
| 21 |
+
|
| 22 |
+
def _generate(contents, retries=3):
|
| 23 |
+
global _total_input_tokens, _total_output_tokens
|
| 24 |
+
for attempt in range(retries):
|
| 25 |
+
try:
|
| 26 |
+
r = client.models.generate_content(model=GEN_MODEL, contents=contents)
|
| 27 |
+
if hasattr(r, "usage_metadata") and r.usage_metadata:
|
| 28 |
+
_total_input_tokens += r.usage_metadata.prompt_token_count or 0
|
| 29 |
+
_total_output_tokens += r.usage_metadata.candidates_token_count or 0
|
| 30 |
+
return r
|
| 31 |
+
except Exception as e:
|
| 32 |
+
if ("503" in str(e) or "UNAVAILABLE" in str(e)) and attempt < retries - 1:
|
| 33 |
+
wait = 2 ** attempt
|
| 34 |
+
print(f"Gemini 503 β retrying in {wait}s (attempt {attempt+1})")
|
| 35 |
+
time.sleep(wait)
|
| 36 |
+
else:
|
| 37 |
+
raise
|
| 38 |
+
|
| 39 |
+
def calc_cost(in_tok, out_tok):
|
| 40 |
+
acra_cost = in_tok * PRICE_IN + out_tok * PRICE_OUT
|
| 41 |
+
gpt4o_cost = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
|
| 42 |
+
savings_pct = round((1 - acra_cost / gpt4o_cost) * 100, 1) if gpt4o_cost else 0
|
| 43 |
+
return {
|
| 44 |
+
"input_tokens": in_tok,
|
| 45 |
+
"output_tokens": out_tok,
|
| 46 |
+
"acra_cost_usd": round(acra_cost, 6),
|
| 47 |
+
"gpt4o_cost_usd": round(gpt4o_cost, 6),
|
| 48 |
+
"savings_pct": savings_pct,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
def embed_texts(texts):
|
| 52 |
r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
|
| 53 |
config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
|
|
|
|
| 71 |
return chunks or [text]
|
| 72 |
|
| 73 |
def web_search(query: str, max_results: int = 5) -> List[dict]:
|
|
|
|
| 74 |
try:
|
| 75 |
import urllib.parse
|
| 76 |
+
encoded = urllib.parse.quote(query)
|
| 77 |
jina_key = os.environ.get("JINA_API_KEY", "")
|
| 78 |
r = httpx.get(
|
| 79 |
f"https://s.jina.ai/?q={encoded}",
|
|
|
|
| 83 |
"X-Retain-Images": "none",
|
| 84 |
"X-Engine": "direct",
|
| 85 |
},
|
| 86 |
+
timeout=20.0,
|
| 87 |
follow_redirects=True
|
| 88 |
)
|
| 89 |
if r.status_code != 200:
|
|
|
|
| 106 |
return []
|
| 107 |
|
| 108 |
def decompose(query):
|
| 109 |
+
r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
|
|
|
|
| 110 |
lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
|
| 111 |
return lines[:4] or [query]
|
| 112 |
|
| 113 |
def compress(query, chunks):
|
| 114 |
if not chunks: return []
|
| 115 |
numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
|
| 116 |
+
r = _generate(
|
| 117 |
f"You have {len(chunks)} text chunks and a query.\n"
|
| 118 |
f"For each chunk extract ONLY sentences relevant to the query.\n"
|
| 119 |
f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
|
| 120 |
+
f"Query: {query}\n\nChunks:\n{numbered}")
|
|
|
|
| 121 |
out = []
|
| 122 |
for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
|
| 123 |
t = m.group(2).strip()
|
|
|
|
| 158 |
return len(chunks)
|
| 159 |
|
| 160 |
async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
|
| 161 |
+
global _total_input_tokens, _total_output_tokens
|
| 162 |
+
_total_input_tokens = _total_output_tokens = 0
|
| 163 |
+
|
| 164 |
cls = classify_query(query)
|
| 165 |
level = cls["level"]
|
| 166 |
k = DEPTH[level]
|
| 167 |
|
|
|
|
| 168 |
if use_web:
|
| 169 |
hits = web_search(query, max_results=6)
|
| 170 |
if not hits:
|
| 171 |
return {"answer": "No web results found.", "sources": [],
|
| 172 |
+
"complexity": cls, "retrieval_source": "none",
|
| 173 |
+
"cost": calc_cost(0, 0)}
|
| 174 |
ctx = "\n\n---\n\n".join(
|
| 175 |
f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
|
| 176 |
+
r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
|
|
|
|
| 177 |
return {
|
| 178 |
"answer": r.text.strip(),
|
| 179 |
"sources": [{"content": h["snippet"][:200],
|
|
|
|
| 181 |
"score": 1.0, "source": "web"} for h in hits],
|
| 182 |
"complexity": cls,
|
| 183 |
"retrieval_source": "web",
|
| 184 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens),
|
| 185 |
}
|
| 186 |
|
|
|
|
| 187 |
if level == 0:
|
| 188 |
doc_hits = vsearch(query, namespace, user_id, 2)
|
| 189 |
if doc_hits:
|
| 190 |
ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
|
| 191 |
+
r = _generate(
|
| 192 |
f"Use the context if it contains a relevant answer. "
|
| 193 |
f"Otherwise answer from your own knowledge.\n\n"
|
| 194 |
+
f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
|
| 195 |
top_score = doc_hits[0].get("similarity", 0)
|
| 196 |
return {
|
| 197 |
"answer": r.text.strip(),
|
|
|
|
| 202 |
for h in doc_hits if h.get("similarity", 0) > 0.5],
|
| 203 |
"complexity": cls,
|
| 204 |
"retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
|
| 205 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens),
|
| 206 |
}
|
| 207 |
+
r = _generate(f"Answer from your knowledge:\n\n{query}")
|
|
|
|
| 208 |
return {"answer": r.text.strip(), "sources": [],
|
| 209 |
+
"complexity": cls, "retrieval_source": "model_knowledge",
|
| 210 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens)}
|
| 211 |
|
|
|
|
| 212 |
hits = []
|
| 213 |
if level == 3:
|
| 214 |
seen = set()
|
|
|
|
| 222 |
web_hits = web_search(query, max_results=k)
|
| 223 |
if not web_hits:
|
| 224 |
return {"answer": "Nothing found locally or on the web.",
|
| 225 |
+
"sources": [], "complexity": cls, "retrieval_source": "none",
|
| 226 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens)}
|
| 227 |
ctx = "\n\n---\n\n".join(
|
| 228 |
f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
|
| 229 |
+
r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
|
|
|
|
| 230 |
return {"answer": r.text.strip(),
|
| 231 |
"sources": [{"content": h["snippet"][:200],
|
| 232 |
"metadata": {"title": h["title"], "url": h["url"]},
|
| 233 |
"score": 1.0, "source": "web"} for h in web_hits],
|
| 234 |
+
"complexity": cls, "retrieval_source": "web",
|
| 235 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens)}
|
| 236 |
|
| 237 |
lc = [h["content"] for h in hits]
|
| 238 |
if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
|
| 239 |
ctx = "\n\n---\n\n".join(lc[:k])
|
| 240 |
+
r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
|
|
|
|
| 241 |
return {
|
| 242 |
"answer": r.text.strip(),
|
| 243 |
"sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
|
| 244 |
"score": h.get("similarity", 0), "source": "local"}
|
| 245 |
for h in hits[:len(lc)]],
|
| 246 |
"complexity": cls, "retrieval_source": "local",
|
| 247 |
+
"cost": calc_cost(_total_input_tokens, _total_output_tokens),
|
| 248 |
}
|
| 249 |
|
| 250 |
async def run_acra_pipeline(mode, **kw):
|