Nottybro commited on
Commit
b0344af
Β·
verified Β·
1 Parent(s): df5b006

fix: 503 retry + token tracking + cost calc per response

Browse files
Files changed (1) hide show
  1. acra.py +64 -26
acra.py CHANGED
@@ -1,4 +1,4 @@
1
- import os, httpx
2
  from google import genai
3
  from google.genai import types
4
  from db import supabase
@@ -10,6 +10,44 @@ EMBED_MODEL = "gemini-embedding-001"
10
  GEN_MODEL = "gemini-3.1-flash-lite-preview"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  def embed_texts(texts):
14
  r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
15
  config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
@@ -33,10 +71,9 @@ def adaptive_chunk(text, max_tok=512):
33
  return chunks or [text]
34
 
35
  def web_search(query: str, max_results: int = 5) -> List[dict]:
36
- """Web search via Jina AI β€” requires JINA_API_KEY secret in Space."""
37
  try:
38
  import urllib.parse
39
- encoded = urllib.parse.quote(query)
40
  jina_key = os.environ.get("JINA_API_KEY", "")
41
  r = httpx.get(
42
  f"https://s.jina.ai/?q={encoded}",
@@ -46,7 +83,7 @@ def web_search(query: str, max_results: int = 5) -> List[dict]:
46
  "X-Retain-Images": "none",
47
  "X-Engine": "direct",
48
  },
49
- timeout=30.0,
50
  follow_redirects=True
51
  )
52
  if r.status_code != 200:
@@ -69,20 +106,18 @@ def web_search(query: str, max_results: int = 5) -> List[dict]:
69
  return []
70
 
71
  def decompose(query):
72
- r = client.models.generate_content(model=GEN_MODEL,
73
- contents=f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
74
  lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
75
  return lines[:4] or [query]
76
 
77
  def compress(query, chunks):
78
  if not chunks: return []
79
  numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
80
- r = client.models.generate_content(model=GEN_MODEL, contents=(
81
  f"You have {len(chunks)} text chunks and a query.\n"
82
  f"For each chunk extract ONLY sentences relevant to the query.\n"
83
  f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
84
- f"Query: {query}\n\nChunks:\n{numbered}"))
85
- import re
86
  out = []
87
  for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
88
  t = m.group(2).strip()
@@ -123,20 +158,22 @@ async def ingest_pipeline(texts, metadata, namespace, user_id):
123
  return len(chunks)
124
 
125
  async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
 
 
 
126
  cls = classify_query(query)
127
  level = cls["level"]
128
  k = DEPTH[level]
129
 
130
- # use_web=True: skip ALL local retrieval, pure Jina web search
131
  if use_web:
132
  hits = web_search(query, max_results=6)
133
  if not hits:
134
  return {"answer": "No web results found.", "sources": [],
135
- "complexity": cls, "retrieval_source": "none"}
 
136
  ctx = "\n\n---\n\n".join(
137
  f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
138
- r = client.models.generate_content(model=GEN_MODEL,
139
- contents=WEB_PROMPT.format(ctx=ctx, q=query))
140
  return {
141
  "answer": r.text.strip(),
142
  "sources": [{"content": h["snippet"][:200],
@@ -144,17 +181,17 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
144
  "score": 1.0, "source": "web"} for h in hits],
145
  "complexity": cls,
146
  "retrieval_source": "web",
 
147
  }
148
 
149
- # L0: check docs first, fall back to model knowledge
150
  if level == 0:
151
  doc_hits = vsearch(query, namespace, user_id, 2)
152
  if doc_hits:
153
  ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
154
- r = client.models.generate_content(model=GEN_MODEL, contents=(
155
  f"Use the context if it contains a relevant answer. "
156
  f"Otherwise answer from your own knowledge.\n\n"
157
- f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"))
158
  top_score = doc_hits[0].get("similarity", 0)
159
  return {
160
  "answer": r.text.strip(),
@@ -165,13 +202,13 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
165
  for h in doc_hits if h.get("similarity", 0) > 0.5],
166
  "complexity": cls,
167
  "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
 
168
  }
169
- r = client.models.generate_content(model=GEN_MODEL,
170
- contents=f"Answer from your knowledge:\n\n{query}")
171
  return {"answer": r.text.strip(), "sources": [],
172
- "complexity": cls, "retrieval_source": "model_knowledge"}
 
173
 
174
- # L1-L3: local vector retrieval
175
  hits = []
176
  if level == 3:
177
  seen = set()
@@ -185,28 +222,29 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
185
  web_hits = web_search(query, max_results=k)
186
  if not web_hits:
187
  return {"answer": "Nothing found locally or on the web.",
188
- "sources": [], "complexity": cls, "retrieval_source": "none"}
 
189
  ctx = "\n\n---\n\n".join(
190
  f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
191
- r = client.models.generate_content(model=GEN_MODEL,
192
- contents=WEB_PROMPT.format(ctx=ctx, q=query))
193
  return {"answer": r.text.strip(),
194
  "sources": [{"content": h["snippet"][:200],
195
  "metadata": {"title": h["title"], "url": h["url"]},
196
  "score": 1.0, "source": "web"} for h in web_hits],
197
- "complexity": cls, "retrieval_source": "web"}
 
198
 
199
  lc = [h["content"] for h in hits]
200
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
201
  ctx = "\n\n---\n\n".join(lc[:k])
202
- r = client.models.generate_content(model=GEN_MODEL,
203
- contents=PROMPTS[level].format(ctx=ctx, q=query))
204
  return {
205
  "answer": r.text.strip(),
206
  "sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
207
  "score": h.get("similarity", 0), "source": "local"}
208
  for h in hits[:len(lc)]],
209
  "complexity": cls, "retrieval_source": "local",
 
210
  }
211
 
212
  async def run_acra_pipeline(mode, **kw):
 
1
+ import os, httpx, time, re
2
  from google import genai
3
  from google.genai import types
4
  from db import supabase
 
10
  GEN_MODEL = "gemini-3.1-flash-lite-preview"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
+ # ── Pricing ────────────────────────────────────────────────────
14
+ PRICE_IN = 0.075 / 1_000_000
15
+ PRICE_OUT = 0.30 / 1_000_000
16
+ GPT4O_IN = 2.50 / 1_000_000
17
+ GPT4O_OUT = 10.00 / 1_000_000
18
+
19
+ _total_input_tokens = 0
20
+ _total_output_tokens = 0
21
+
22
+ def _generate(contents, retries=3):
23
+ global _total_input_tokens, _total_output_tokens
24
+ for attempt in range(retries):
25
+ try:
26
+ r = client.models.generate_content(model=GEN_MODEL, contents=contents)
27
+ if hasattr(r, "usage_metadata") and r.usage_metadata:
28
+ _total_input_tokens += r.usage_metadata.prompt_token_count or 0
29
+ _total_output_tokens += r.usage_metadata.candidates_token_count or 0
30
+ return r
31
+ except Exception as e:
32
+ if ("503" in str(e) or "UNAVAILABLE" in str(e)) and attempt < retries - 1:
33
+ wait = 2 ** attempt
34
+ print(f"Gemini 503 β€” retrying in {wait}s (attempt {attempt+1})")
35
+ time.sleep(wait)
36
+ else:
37
+ raise
38
+
39
+ def calc_cost(in_tok, out_tok):
40
+ acra_cost = in_tok * PRICE_IN + out_tok * PRICE_OUT
41
+ gpt4o_cost = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
42
+ savings_pct = round((1 - acra_cost / gpt4o_cost) * 100, 1) if gpt4o_cost else 0
43
+ return {
44
+ "input_tokens": in_tok,
45
+ "output_tokens": out_tok,
46
+ "acra_cost_usd": round(acra_cost, 6),
47
+ "gpt4o_cost_usd": round(gpt4o_cost, 6),
48
+ "savings_pct": savings_pct,
49
+ }
50
+
51
  def embed_texts(texts):
52
  r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
53
  config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
 
71
  return chunks or [text]
72
 
73
  def web_search(query: str, max_results: int = 5) -> List[dict]:
 
74
  try:
75
  import urllib.parse
76
+ encoded = urllib.parse.quote(query)
77
  jina_key = os.environ.get("JINA_API_KEY", "")
78
  r = httpx.get(
79
  f"https://s.jina.ai/?q={encoded}",
 
83
  "X-Retain-Images": "none",
84
  "X-Engine": "direct",
85
  },
86
+ timeout=20.0,
87
  follow_redirects=True
88
  )
89
  if r.status_code != 200:
 
106
  return []
107
 
108
  def decompose(query):
109
+ r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
 
110
  lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
111
  return lines[:4] or [query]
112
 
113
  def compress(query, chunks):
114
  if not chunks: return []
115
  numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
116
+ r = _generate(
117
  f"You have {len(chunks)} text chunks and a query.\n"
118
  f"For each chunk extract ONLY sentences relevant to the query.\n"
119
  f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
120
+ f"Query: {query}\n\nChunks:\n{numbered}")
 
121
  out = []
122
  for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
123
  t = m.group(2).strip()
 
158
  return len(chunks)
159
 
160
  async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
161
+ global _total_input_tokens, _total_output_tokens
162
+ _total_input_tokens = _total_output_tokens = 0
163
+
164
  cls = classify_query(query)
165
  level = cls["level"]
166
  k = DEPTH[level]
167
 
 
168
  if use_web:
169
  hits = web_search(query, max_results=6)
170
  if not hits:
171
  return {"answer": "No web results found.", "sources": [],
172
+ "complexity": cls, "retrieval_source": "none",
173
+ "cost": calc_cost(0, 0)}
174
  ctx = "\n\n---\n\n".join(
175
  f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
176
+ r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
 
177
  return {
178
  "answer": r.text.strip(),
179
  "sources": [{"content": h["snippet"][:200],
 
181
  "score": 1.0, "source": "web"} for h in hits],
182
  "complexity": cls,
183
  "retrieval_source": "web",
184
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens),
185
  }
186
 
 
187
  if level == 0:
188
  doc_hits = vsearch(query, namespace, user_id, 2)
189
  if doc_hits:
190
  ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
191
+ r = _generate(
192
  f"Use the context if it contains a relevant answer. "
193
  f"Otherwise answer from your own knowledge.\n\n"
194
+ f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
195
  top_score = doc_hits[0].get("similarity", 0)
196
  return {
197
  "answer": r.text.strip(),
 
202
  for h in doc_hits if h.get("similarity", 0) > 0.5],
203
  "complexity": cls,
204
  "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
205
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens),
206
  }
207
+ r = _generate(f"Answer from your knowledge:\n\n{query}")
 
208
  return {"answer": r.text.strip(), "sources": [],
209
+ "complexity": cls, "retrieval_source": "model_knowledge",
210
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
211
 
 
212
  hits = []
213
  if level == 3:
214
  seen = set()
 
222
  web_hits = web_search(query, max_results=k)
223
  if not web_hits:
224
  return {"answer": "Nothing found locally or on the web.",
225
+ "sources": [], "complexity": cls, "retrieval_source": "none",
226
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
227
  ctx = "\n\n---\n\n".join(
228
  f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
229
+ r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
 
230
  return {"answer": r.text.strip(),
231
  "sources": [{"content": h["snippet"][:200],
232
  "metadata": {"title": h["title"], "url": h["url"]},
233
  "score": 1.0, "source": "web"} for h in web_hits],
234
+ "complexity": cls, "retrieval_source": "web",
235
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
236
 
237
  lc = [h["content"] for h in hits]
238
  if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
239
  ctx = "\n\n---\n\n".join(lc[:k])
240
+ r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
 
241
  return {
242
  "answer": r.text.strip(),
243
  "sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
244
  "score": h.get("similarity", 0), "source": "local"}
245
  for h in hits[:len(lc)]],
246
  "complexity": cls, "retrieval_source": "local",
247
+ "cost": calc_cost(_total_input_tokens, _total_output_tokens),
248
  }
249
 
250
  async def run_acra_pipeline(mode, **kw):