Nottybro commited on
Commit
fd390a0
Β·
verified Β·
1 Parent(s): c837482

fix: robust token tracking, retry, Jina fallback, no more 500s

Browse files
Files changed (1) hide show
  1. acra.py +122 -104
acra.py CHANGED
@@ -10,42 +10,60 @@ EMBED_MODEL = "gemini-embedding-001"
10
  GEN_MODEL = "gemini-3.1-flash-lite-preview"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
- # ── Pricing ────────────────────────────────────────────────────
14
- PRICE_IN = 0.075 / 1_000_000
15
- PRICE_OUT = 0.30 / 1_000_000
16
- GPT4O_IN = 2.50 / 1_000_000
17
- GPT4O_OUT = 10.00 / 1_000_000
18
 
19
  _total_input_tokens = 0
20
  _total_output_tokens = 0
21
 
22
- def _generate(contents, retries=3):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  global _total_input_tokens, _total_output_tokens
 
24
  for attempt in range(retries):
25
  try:
26
  r = client.models.generate_content(model=GEN_MODEL, contents=contents)
27
- if hasattr(r, "usage_metadata") and r.usage_metadata:
28
- _total_input_tokens += r.usage_metadata.prompt_token_count or 0
29
- _total_output_tokens += r.usage_metadata.candidates_token_count or 0
30
  return r
31
  except Exception as e:
32
- if ("503" in str(e) or "UNAVAILABLE" in str(e)) and attempt < retries - 1:
33
- wait = 2 ** attempt
34
- print(f"Gemini 503 β€” retrying in {wait}s (attempt {attempt+1})")
 
 
35
  time.sleep(wait)
36
  else:
37
  raise
 
38
 
39
  def calc_cost(in_tok, out_tok):
40
- acra_cost = in_tok * PRICE_IN + out_tok * PRICE_OUT
41
- gpt4o_cost = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
42
- savings_pct = round((1 - acra_cost / gpt4o_cost) * 100, 1) if gpt4o_cost else 0
43
  return {
44
  "input_tokens": in_tok,
45
  "output_tokens": out_tok,
46
- "acra_cost_usd": round(acra_cost, 6),
47
- "gpt4o_cost_usd": round(gpt4o_cost, 6),
48
- "savings_pct": savings_pct,
49
  }
50
 
51
  def embed_texts(texts):
@@ -71,39 +89,43 @@ def adaptive_chunk(text, max_tok=512):
71
  return chunks or [text]
72
 
73
  def web_search(query: str, max_results: int = 5) -> List[dict]:
74
- try:
75
- import urllib.parse
76
- encoded = urllib.parse.quote(query)
77
- jina_key = os.environ.get("JINA_API_KEY", "")
78
- r = httpx.get(
79
- f"https://s.jina.ai/?q={encoded}",
80
- headers={
81
- "Authorization": f"Bearer {jina_key}",
82
- "Accept": "application/json",
83
- "X-Retain-Images": "none",
84
- "X-Engine": "direct",
85
- },
86
- timeout=20.0,
87
- follow_redirects=True
88
- )
89
- if r.status_code != 200:
90
- print(f"Jina returned {r.status_code}: {r.text[:200]}")
91
- return []
92
- data = r.json()
93
- items = data.get("data", [])
94
- out = []
95
- for item in items[:max_results]:
96
- snippet = item.get("description") or item.get("content", "")
97
- if snippet:
98
- out.append({
99
- "title": item.get("title", ""),
100
- "snippet": snippet[:500],
101
- "url": item.get("url", "")
102
- })
103
- return out
104
- except Exception as e:
105
- print(f"Web search error: {e}")
106
- return []
 
 
 
 
107
 
108
  def decompose(query):
109
  r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
@@ -165,87 +187,83 @@ async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False
165
  level = cls["level"]
166
  k = DEPTH[level]
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  if use_web:
169
  hits = web_search(query, max_results=6)
170
  if not hits:
171
- return {"answer": "No web results found.", "sources": [],
172
- "complexity": cls, "retrieval_source": "none",
173
- "cost": calc_cost(0, 0)}
 
 
174
  ctx = "\n\n---\n\n".join(
175
  f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
176
  r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
177
- return {
178
- "answer": r.text.strip(),
179
- "sources": [{"content": h["snippet"][:200],
180
- "metadata": {"title": h["title"], "url": h["url"]},
181
- "score": 1.0, "source": "web"} for h in hits],
182
- "complexity": cls,
183
- "retrieval_source": "web",
184
- "cost": calc_cost(_total_input_tokens, _total_output_tokens),
185
- }
186
 
 
187
  if level == 0:
188
  doc_hits = vsearch(query, namespace, user_id, 2)
189
  if doc_hits:
190
- ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
191
- r = _generate(
192
- f"Use the context if it contains a relevant answer. "
193
- f"Otherwise answer from your own knowledge.\n\n"
194
- f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
195
  top_score = doc_hits[0].get("similarity", 0)
196
- return {
197
- "answer": r.text.strip(),
198
- "sources": [{"content": h["content"][:200],
199
- "metadata": h.get("metadata", {}),
200
- "score": h.get("similarity", 0),
201
- "source": "local"}
202
- for h in doc_hits if h.get("similarity", 0) > 0.5],
203
- "complexity": cls,
204
- "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
205
- "cost": calc_cost(_total_input_tokens, _total_output_tokens),
206
- }
207
  r = _generate(f"Answer from your knowledge:\n\n{query}")
208
  return {"answer": r.text.strip(), "sources": [],
209
- "complexity": cls, "retrieval_source": "model_knowledge",
210
- "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
211
 
 
212
  hits = []
213
  if level == 3:
214
  seen = set()
215
  for sq in decompose(query):
216
  for h in vsearch(sq, namespace, user_id, 4):
217
- if h["id"] not in seen: seen.add(h["id"]); hits.append(h)
 
218
  else:
219
  hits = vsearch(query, namespace, user_id, k)
220
 
 
221
  if not hits:
222
  web_hits = web_search(query, max_results=k)
223
- if not web_hits:
224
- return {"answer": "Nothing found locally or on the web.",
225
- "sources": [], "complexity": cls, "retrieval_source": "none",
226
- "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
227
- ctx = "\n\n---\n\n".join(
228
- f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
229
- r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
230
- return {"answer": r.text.strip(),
231
- "sources": [{"content": h["snippet"][:200],
232
- "metadata": {"title": h["title"], "url": h["url"]},
233
- "score": 1.0, "source": "web"} for h in web_hits],
234
- "complexity": cls, "retrieval_source": "web",
235
- "cost": calc_cost(_total_input_tokens, _total_output_tokens)}
236
 
237
  lc = [h["content"] for h in hits]
238
- if rerank and level >= 2: lc = [c for c in compress(query, lc) if c.strip()]
 
239
  ctx = "\n\n---\n\n".join(lc[:k])
240
  r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
241
- return {
242
- "answer": r.text.strip(),
243
- "sources": [{"content": h["content"][:200], "metadata": h.get("metadata", {}),
244
- "score": h.get("similarity", 0), "source": "local"}
245
- for h in hits[:len(lc)]],
246
- "complexity": cls, "retrieval_source": "local",
247
- "cost": calc_cost(_total_input_tokens, _total_output_tokens),
248
- }
249
 
250
  async def run_acra_pipeline(mode, **kw):
251
  if mode == "ingest":
 
10
  GEN_MODEL = "gemini-3.1-flash-lite-preview"
11
  DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}
12
 
13
+ PRICE_IN = 0.075 / 1_000_000
14
+ PRICE_OUT = 0.30 / 1_000_000
15
+ GPT4O_IN = 2.50 / 1_000_000
16
+ GPT4O_OUT = 10.00 / 1_000_000
 
17
 
18
  _total_input_tokens = 0
19
  _total_output_tokens = 0
20
 
21
+ def _get_tokens(usage_metadata):
22
+ """Safely extract tokens β€” field names differ across SDK versions."""
23
+ if not usage_metadata:
24
+ return 0, 0
25
+ in_tok = (
26
+ getattr(usage_metadata, "prompt_token_count", None) or
27
+ getattr(usage_metadata, "input_token_count", None) or
28
+ getattr(usage_metadata, "total_token_count", None) or 0
29
+ )
30
+ out_tok = (
31
+ getattr(usage_metadata, "candidates_token_count", None) or
32
+ getattr(usage_metadata, "output_token_count", None) or 0
33
+ )
34
+ return in_tok, out_tok
35
+
36
+ def _generate(contents, retries=4):
37
+ """Gemini call with retry on 503/429 + robust token tracking."""
38
  global _total_input_tokens, _total_output_tokens
39
+ last_err = None
40
  for attempt in range(retries):
41
  try:
42
  r = client.models.generate_content(model=GEN_MODEL, contents=contents)
43
+ in_tok, out_tok = _get_tokens(getattr(r, "usage_metadata", None))
44
+ _total_input_tokens += in_tok
45
+ _total_output_tokens += out_tok
46
  return r
47
  except Exception as e:
48
+ last_err = e
49
+ err_str = str(e)
50
+ if any(code in err_str for code in ["503", "429", "UNAVAILABLE", "Resource"]):
51
+ wait = 2 ** attempt # 1, 2, 4, 8s
52
+ print(f"Gemini {err_str[:40]} β€” retry {attempt+1}/{retries} in {wait}s")
53
  time.sleep(wait)
54
  else:
55
  raise
56
+ raise RuntimeError(f"Gemini unavailable after {retries} retries: {last_err}")
57
 
58
  def calc_cost(in_tok, out_tok):
59
+ acra = in_tok * PRICE_IN + out_tok * PRICE_OUT
60
+ gpt4o = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
 
61
  return {
62
  "input_tokens": in_tok,
63
  "output_tokens": out_tok,
64
+ "acra_cost_usd": round(acra, 6),
65
+ "gpt4o_cost_usd": round(gpt4o, 6),
66
+ "savings_pct": round((1 - acra / gpt4o) * 100, 1) if gpt4o else 0,
67
  }
68
 
69
  def embed_texts(texts):
 
89
  return chunks or [text]
90
 
91
  def web_search(query: str, max_results: int = 5) -> List[dict]:
92
+ """Jina web search with automatic query simplification fallback."""
93
+ jina_key = os.environ.get("JINA_API_KEY", "")
94
+ queries_to_try = [query, " ".join(query.split()[:8])] # full, then simplified
95
+ for attempt_q in queries_to_try:
96
+ try:
97
+ import urllib.parse
98
+ encoded = urllib.parse.quote(attempt_q)
99
+ r = httpx.get(
100
+ f"https://s.jina.ai/?q={encoded}",
101
+ headers={
102
+ "Authorization": f"Bearer {jina_key}",
103
+ "Accept": "application/json",
104
+ "X-Retain-Images": "none",
105
+ "X-Engine": "direct",
106
+ },
107
+ timeout=25.0,
108
+ follow_redirects=True
109
+ )
110
+ if r.status_code != 200:
111
+ print(f"Jina {r.status_code} on attempt query: {attempt_q[:60]}")
112
+ continue
113
+ items = r.json().get("data", [])
114
+ out = []
115
+ for item in items[:max_results]:
116
+ snippet = item.get("description") or item.get("content", "")
117
+ if snippet:
118
+ out.append({
119
+ "title": item.get("title", ""),
120
+ "snippet": snippet[:600],
121
+ "url": item.get("url", "")
122
+ })
123
+ if out:
124
+ return out
125
+ except Exception as e:
126
+ print(f"Web search error: {e}")
127
+ continue
128
+ return []
129
 
130
  def decompose(query):
131
  r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
 
187
  level = cls["level"]
188
  k = DEPTH[level]
189
 
190
+ def _cost():
191
+ return calc_cost(_total_input_tokens, _total_output_tokens)
192
+
193
+ def _web_sources(hits):
194
+ return [{"content": h["snippet"][:200],
195
+ "metadata": {"title": h["title"], "url": h["url"]},
196
+ "score": 1.0, "source": "web"} for h in hits]
197
+
198
+ def _local_sources(hits):
199
+ return [{"content": h["content"][:200],
200
+ "metadata": h.get("metadata", {}),
201
+ "score": h.get("similarity", 0),
202
+ "source": "local"} for h in hits]
203
+
204
+ # ── use_web=True: pure Jina search ──────────────────────────
205
  if use_web:
206
  hits = web_search(query, max_results=6)
207
  if not hits:
208
+ # Last resort: answer from model knowledge
209
+ r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
210
+ return {"answer": r.text.strip(), "sources": [],
211
+ "complexity": cls, "retrieval_source": "model_knowledge",
212
+ "cost": _cost()}
213
  ctx = "\n\n---\n\n".join(
214
  f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
215
  r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
216
+ return {"answer": r.text.strip(), "sources": _web_sources(hits),
217
+ "complexity": cls, "retrieval_source": "web", "cost": _cost()}
 
 
 
 
 
 
 
218
 
219
+ # ── L0 ───────────────────────────────────────────────────────
220
  if level == 0:
221
  doc_hits = vsearch(query, namespace, user_id, 2)
222
  if doc_hits:
223
+ ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
224
+ r = _generate(f"Use the context if relevant, else answer from knowledge.\n\n"
225
+ f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
 
 
226
  top_score = doc_hits[0].get("similarity", 0)
227
+ return {"answer": r.text.strip(),
228
+ "sources": [s for s in _local_sources(doc_hits) if s["score"] > 0.5],
229
+ "complexity": cls,
230
+ "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
231
+ "cost": _cost()}
 
 
 
 
 
 
232
  r = _generate(f"Answer from your knowledge:\n\n{query}")
233
  return {"answer": r.text.strip(), "sources": [],
234
+ "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}
 
235
 
236
+ # ── L1-L3: local vector search ───────────────────────────────
237
  hits = []
238
  if level == 3:
239
  seen = set()
240
  for sq in decompose(query):
241
  for h in vsearch(sq, namespace, user_id, 4):
242
+ if h["id"] not in seen:
243
+ seen.add(h["id"]); hits.append(h)
244
  else:
245
  hits = vsearch(query, namespace, user_id, k)
246
 
247
+ # Fallback to web if no local docs
248
  if not hits:
249
  web_hits = web_search(query, max_results=k)
250
+ if web_hits:
251
+ ctx = "\n\n---\n\n".join(f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
252
+ r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
253
+ return {"answer": r.text.strip(), "sources": _web_sources(web_hits),
254
+ "complexity": cls, "retrieval_source": "web", "cost": _cost()}
255
+ # Final fallback: model knowledge
256
+ r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
257
+ return {"answer": r.text.strip(), "sources": [],
258
+ "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}
 
 
 
 
259
 
260
  lc = [h["content"] for h in hits]
261
+ if rerank and level >= 2:
262
+ lc = [c for c in compress(query, lc) if c.strip()] or lc
263
  ctx = "\n\n---\n\n".join(lc[:k])
264
  r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
265
+ return {"answer": r.text.strip(), "sources": _local_sources(hits[:len(lc)]),
266
+ "complexity": cls, "retrieval_source": "local", "cost": _cost()}
 
 
 
 
 
 
267
 
268
  async def run_acra_pipeline(mode, **kw):
269
  if mode == "ingest":