import os, httpx, time, re
import urllib.parse
from google import genai
from google.genai import types
from db import supabase
from classifier_inference import classify_query
from typing import List


client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
EMBED_MODEL = "gemini-embedding-001"
GEN_MODEL = "gemini-3.1-flash-lite-preview"
DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}  # retrieval depth (top-k) per complexity level

# USD per token (per-million-token list prices divided out)
PRICE_IN = 0.075 / 1_000_000
PRICE_OUT = 0.30 / 1_000_000
GPT4O_IN = 2.50 / 1_000_000
GPT4O_OUT = 10.00 / 1_000_000

# Running token counters, reset at the start of each pipeline run
_total_input_tokens = 0
_total_output_tokens = 0


def _get_tokens(usage_metadata):
    """Safely extract token counts -- field names differ across SDK versions."""
    if not usage_metadata:
        return 0, 0
    in_tok = (
        getattr(usage_metadata, "prompt_token_count", None) or
        getattr(usage_metadata, "input_token_count", None) or
        # last resort: total count over-approximates input but keeps billing non-zero
        getattr(usage_metadata, "total_token_count", None) or 0
    )
    out_tok = (
        getattr(usage_metadata, "candidates_token_count", None) or
        getattr(usage_metadata, "output_token_count", None) or 0
    )
    return in_tok, out_tok


def _generate(contents, retries=4):
    """Gemini call with retry on 503/429 + robust token tracking."""
    global _total_input_tokens, _total_output_tokens
    last_err = None
    for attempt in range(retries):
        try:
            r = client.models.generate_content(model=GEN_MODEL, contents=contents)
            in_tok, out_tok = _get_tokens(getattr(r, "usage_metadata", None))
            _total_input_tokens += in_tok
            _total_output_tokens += out_tok
            return r
        except Exception as e:
            last_err = e
            err_str = str(e)
            if any(code in err_str for code in ["503", "429", "UNAVAILABLE", "Resource"]):
                wait = 2 ** attempt
                print(f"Gemini {err_str[:40]} -> retry {attempt+1}/{retries} in {wait}s")
                time.sleep(wait)
            else:
                raise
    raise RuntimeError(f"Gemini unavailable after {retries} retries: {last_err}")
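
# Usage sketch: _generate returns the raw SDK response; callers read r.text.
# Only transient 503/429-style errors are retried, with 1s/2s/4s/8s backoff
# (2**attempt); anything else re-raises immediately.
#   r = _generate("Say hi")
#   print(r.text)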


def calc_cost(in_tok, out_tok):
    acra = in_tok * PRICE_IN + out_tok * PRICE_OUT
    gpt4o = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
    return {
        "input_tokens": in_tok,
        "output_tokens": out_tok,
        "acra_cost_usd": round(acra, 6),
        "gpt4o_cost_usd": round(gpt4o, 6),
        "savings_pct": round((1 - acra / gpt4o) * 100, 1) if gpt4o else 0,
    }
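
# Worked example (hypothetical token counts):
#   calc_cost(10_000, 1_000)
#   acra  = 10_000 * 0.075/1e6 + 1_000 * 0.30/1e6  = 0.00075 + 0.0003 = $0.00105
#   gpt4o = 10_000 * 2.50/1e6  + 1_000 * 10.00/1e6 = 0.025   + 0.010  = $0.035
#   savings_pct = (1 - 0.00105 / 0.035) * 100 = 97.0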


def embed_texts(texts):
    r = client.models.embed_content(
        model=EMBED_MODEL, contents=texts,
        config=types.EmbedContentConfig(
            task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
    return [e.values for e in r.embeddings]


def embed_query(q):
    r = client.models.embed_content(
        model=EMBED_MODEL, contents=[q],
        config=types.EmbedContentConfig(
            task_type="RETRIEVAL_QUERY", output_dimensionality=768))
    return r.embeddings[0].values
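
# Both helpers request 768-dim embeddings (output_dimensionality=768); documents
# and queries use different task types so the model can embed them asymmetrically
# for retrieval. Example:
#   vec = embed_query("what is RAG?")   # -> list of 768 floats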


def adaptive_chunk(text, max_tok=512):
    """Greedily pack paragraphs into chunks of roughly max_tok tokens."""
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks, cur = [], ""
    for p in paras:
        # word count / 0.75 approximates the token count of the merged chunk
        if (len(cur.split()) + len(p.split())) / 0.75 < max_tok:
            cur = (cur + "\n\n" + p).strip()
        else:
            if cur:
                chunks.append(cur)
            cur = p
    if cur:
        chunks.append(cur)
    return chunks or [text]
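
# Illustrative behaviour (made-up input): paragraphs are packed while the
# ~0.75-words-per-token estimate stays under max_tok, then a new chunk starts;
# a single oversized paragraph becomes its own chunk.
#   adaptive_chunk("a\n\nb\n\nc", max_tok=512)  # -> ["a\n\nb\n\nc"]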


def web_search(query: str, max_results: int = 5) -> List[dict]:
    """Jina web search with automatic query simplification fallback."""
    jina_key = os.environ.get("JINA_API_KEY", "")
    # Try the full query, then a version truncated to its first 8 words;
    # dict.fromkeys dedupes so short queries aren't searched twice.
    queries_to_try = list(dict.fromkeys([query, " ".join(query.split()[:8])]))
    for attempt_q in queries_to_try:
        try:
            encoded = urllib.parse.quote(attempt_q)
            r = httpx.get(
                f"https://s.jina.ai/?q={encoded}",
                headers={
                    "Authorization": f"Bearer {jina_key}",
                    "Accept": "application/json",
                    "X-Retain-Images": "none",
                    "X-Engine": "direct",
                },
                timeout=25.0,
                follow_redirects=True,
            )
            if r.status_code != 200:
                print(f"Jina {r.status_code} on attempt query: {attempt_q[:60]}")
                continue
            items = r.json().get("data", [])
            out = []
            for item in items[:max_results]:
                snippet = item.get("description") or item.get("content", "")
                if snippet:
                    out.append({
                        "title": item.get("title", ""),
                        "snippet": snippet[:600],
                        "url": item.get("url", ""),
                    })
            if out:
                return out
        except Exception as e:
            print(f"Web search error: {e}")
            continue
    return []
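
# Usage sketch (requires JINA_API_KEY in the environment):
#   hits = web_search("latest Gemini embedding pricing", max_results=3)
#   # -> [{"title": ..., "snippet": ..., "url": ...}, ...] or [] on total failure
# The 8-word truncation fallback exists because overly long queries are a
# common cause of empty search results.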


def decompose(query):
    r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
    # Strip list markers ("1.", "2)", "-") without eating digits that start the
    # content itself, which a character-set lstrip would do.
    lines = [re.sub(r"^\s*(?:\d+[.)]|-)\s*", "", l).strip()
             for l in r.text.strip().split("\n") if l.strip()]
    return [l for l in lines if l][:4] or [query]
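
# Example (illustrative): decompose("How did X affect Y and Z?") might return
#   ["What is X?", "How does X affect Y?", "How does X affect Z?"]
# At most 4 sub-queries are kept; on an empty parse the original query is used.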


def compress(query, chunks):
    if not chunks:
        return []
    numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
    r = _generate(
        f"You have {len(chunks)} text chunks and a query.\n"
        f"For each chunk extract ONLY sentences relevant to the query.\n"
        f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
        f"Query: {query}\n\nChunks:\n{numbered}")
    out = []
    for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
        t = m.group(2).strip()
        if t and t.upper() != "EMPTY":
            out.append(t)
    return out if out else chunks
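
# The model replies in the same numbered format it was given, e.g.
#   [1] relevant sentences...
#   [2] EMPTY
#   [3] relevant sentences...
# The regex splits on the [n] markers and drops EMPTY entries; if parsing
# yields nothing, the original chunks are returned unchanged as a safe fallback.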


def vsearch(query, namespace, user_id, k):
    return (supabase.rpc("match_documents", {
        "query_embedding": embed_query(query),
        "match_count": k,
        "filter_namespace": namespace,
        "filter_user_id": user_id,
    }).execute().data or [])
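
# Assumes a pgvector-style `match_documents` function exists in the Supabase
# database (its SQL definition is not part of this module). Expected contract:
# take a 768-dim query embedding plus namespace/user filters, and return rows
# shaped like {"id", "content", "metadata", "similarity"} ordered by
# similarity -- the fields the pipelines below read.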


PROMPTS = {
    0: "Answer from your knowledge:\n\n{q}",
    1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
    2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
    3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
}
WEB_PROMPT = (
    "Answer the question using ONLY the web search results below.\n"
    "Be factual and concise.\n\n"
    "Web results:\n{ctx}\n\nQuestion: {q}\nAnswer:"
)


async def ingest_pipeline(texts, metadata, namespace, user_id):
    chunks, meta = [], []
    for i, t in enumerate(texts):
        for j, c in enumerate(adaptive_chunk(t)):
            chunks.append(c)
            meta.append({**metadata[i], "source_index": i, "chunk_index": j})
    rows = [{"content": c, "embedding": e, "metadata": m,
             "namespace": namespace, "user_id": user_id}
            for c, e, m in zip(chunks, embed_texts(chunks), meta)]
    for i in range(0, len(rows), 50):
        supabase.table("documents").insert(rows[i:i+50]).execute()
    return len(chunks)
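
# Usage sketch (inside an async context; assumes the `documents` table exists):
#   n = await ingest_pipeline(
#       texts=["full doc text..."], metadata=[{"filename": "doc.txt"}],
#       namespace="demo", user_id="u1")
#   # n == number of chunks written; inserts are batched 50 rows at a time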


async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
    """Route a query by complexity level. top_k is accepted for API
    compatibility; actual retrieval depth comes from DEPTH[level]."""
    global _total_input_tokens, _total_output_tokens
    _total_input_tokens = _total_output_tokens = 0

    cls = classify_query(query)
    level = cls["level"]
    k = DEPTH[level]

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _web_sources(hits):
        return [{"content": h["snippet"][:200],
                 "metadata": {"title": h["title"], "url": h["url"]},
                 "score": 1.0, "source": "web"} for h in hits]

    def _local_sources(hits):
        return [{"content": h["content"][:200],
                 "metadata": h.get("metadata", {}),
                 "score": h.get("similarity", 0),
                 "source": "local"} for h in hits]

    # Web mode: search first, fall back to model knowledge if search is empty
    if use_web:
        hits = web_search(query, max_results=6)
        if not hits:
            r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
            return {"answer": r.text.strip(), "sources": [],
                    "complexity": cls, "retrieval_source": "model_knowledge",
                    "cost": _cost()}
        ctx = "\n\n---\n\n".join(
            f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
        r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
        return {"answer": r.text.strip(), "sources": _web_sources(hits),
                "complexity": cls, "retrieval_source": "web", "cost": _cost()}

    # Level 0: trivial query -- shallow local peek, else pure model knowledge
    if level == 0:
        doc_hits = vsearch(query, namespace, user_id, 2)
        if doc_hits:
            ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
            r = _generate(f"Use the context if relevant, else answer from knowledge.\n\n"
                          f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
            top_score = doc_hits[0].get("similarity", 0)
            return {"answer": r.text.strip(),
                    "sources": [s for s in _local_sources(doc_hits) if s["score"] > 0.5],
                    "complexity": cls,
                    "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
                    "cost": _cost()}
        r = _generate(f"Answer from your knowledge:\n\n{query}")
        return {"answer": r.text.strip(), "sources": [],
                "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}

    # Levels 1-3: vector retrieval; level 3 decomposes into sub-queries first
    hits = []
    if level == 3:
        seen = set()
        for sq in decompose(query):
            for h in vsearch(sq, namespace, user_id, 4):
                if h["id"] not in seen:
                    seen.add(h["id"])
                    hits.append(h)
    else:
        hits = vsearch(query, namespace, user_id, k)

    # No local hits: try the web, then fall back to model knowledge
    if not hits:
        web_hits = web_search(query, max_results=k)
        if web_hits:
            ctx = "\n\n---\n\n".join(f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
            r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
            return {"answer": r.text.strip(), "sources": _web_sources(web_hits),
                    "complexity": cls, "retrieval_source": "web", "cost": _cost()}
        r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
        return {"answer": r.text.strip(), "sources": [],
                "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}

    # Local hits: optionally compress the chunks, then answer with the level prompt
    lc = [h["content"] for h in hits]
    if rerank and level >= 2:
        lc = [c for c in compress(query, lc) if c.strip()] or lc
    ctx = "\n\n---\n\n".join(lc[:k])
    r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
    return {"answer": r.text.strip(), "sources": _local_sources(hits[:len(lc)]),
            "complexity": cls, "retrieval_source": "local", "cost": _cost()}


async def byo_llm_pipeline(query, llm_endpoint, llm_api_key, llm_model, namespace, user_id):
    """
    BYO-LLM pipeline:
    1. Ask their LLM what it doesn't know (gap detection)
    2. Jina fetches exactly those gaps
    3. Gemini cleans the raw web mess into clean bullet points
    4. Their LLM answers with a tiny clean context
    """
    global _total_input_tokens, _total_output_tokens
    _total_input_tokens = _total_output_tokens = 0

    cls = classify_query(query)
    level = cls["level"]

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _call_their_llm(messages):
        """Call their OpenAI-compatible endpoint."""
        r = httpx.post(
            llm_endpoint,
            headers={
                "Authorization": f"Bearer {llm_api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": llm_model,
                "messages": messages,
                "max_tokens": 300,
                "temperature": 0.1,
            },
            timeout=30.0,
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"].strip()

    # Step 1: gap detection -- ask their LLM what it is missing
    gap_prompt = (
        f"You will answer a user query. Before answering, identify ONLY what you are "
        f"uncertain or lack recent data about.\n"
        f"Reply with max 6 short lines like:\n"
        f"- I don't know: [specific gap]\n\n"
        f"Query: {query}\n\n"
        f"What are your knowledge gaps? 6 lines max, be specific."
    )
    try:
        gaps_text = _call_their_llm([{"role": "user", "content": gap_prompt}])
        print(f"Gaps detected: {gaps_text[:200]}")
    except Exception as e:
        print(f"Gap detection failed: {e} -> falling back to full query search")
        gaps_text = query

    # Parse gap lines, dropping list markers and boilerplate prefixes
    gap_lines = [l.strip().lstrip("- ").replace("I don't know:", "").replace("I am unsure about:", "").strip()
                 for l in gaps_text.split("\n") if l.strip() and len(l.strip()) > 10][:6]

    # Step 2: one web search per gap, deduplicated by URL
    all_hits = []
    seen_urls = set()
    for gap in gap_lines or [query]:
        hits = web_search(gap, max_results=2)
        for h in hits:
            if h["url"] not in seen_urls:
                seen_urls.add(h["url"])
                all_hits.append(h)

    if not all_hits:
        # Nothing found on the web -- let their LLM answer unaided
        try:
            answer = _call_their_llm([{"role": "user", "content": query}])
        except Exception as e:
            answer = f"LLM call failed: {e}"
        return {"answer": answer, "sources": [], "complexity": cls,
                "retrieval_source": "model_knowledge", "cost": _cost()}

    # Step 3: Gemini distills the raw snippets into clean bullets
    raw_ctx = "\n\n---\n\n".join(
        f"Source: {h['title']}\n{h['snippet']}" for h in all_hits)

    clean_prompt = (
        f"You are a data cleaner. Extract ONLY facts relevant to this query.\n"
        f"Format: bullet points, max 15 words per bullet, no fluff, no URLs.\n"
        f"Output max 10 bullets total.\n\n"
        f"Query: {query}\n\nRaw web data:\n{raw_ctx}\n\nClean bullets:"
    )
    clean_r = _generate(clean_prompt)
    clean_ctx = clean_r.text.strip()
    print(f"Cleaned context ({len(clean_ctx)} chars):\n{clean_ctx[:300]}")

    # Step 4: their LLM answers with the distilled context
    final_messages = [
        {"role": "system", "content":
            "You are a helpful assistant. Use the provided context to answer accurately. "
            "If context doesn't help, use your own knowledge."},
        {"role": "user", "content":
            f"Context (verified web facts):\n{clean_ctx}\n\nQuestion: {query}\nAnswer:"}
    ]
    try:
        answer = _call_their_llm(final_messages)
    except Exception as e:
        print(f"Their LLM failed: {e} -> falling back to Gemini")
        r = _generate(f"Context:\n{clean_ctx}\n\nQuestion: {query}\nAnswer:")
        answer = r.text.strip()

    return {
        "answer": answer,
        "sources": [{"content": h["snippet"][:200],
                     "metadata": {"title": h["title"], "url": h["url"]},
                     "score": 1.0, "source": "web"} for h in all_hits],
        "complexity": cls,
        "retrieval_source": "byo_llm+web",
        "cost": _cost(),
    }


async def run_acra_pipeline(mode, **kw):
    if mode == "ingest":
        return await ingest_pipeline(kw["texts"], kw["metadata"],
                                     kw["namespace"], kw["user_id"])
    if kw.get("llm_endpoint"):
        return await byo_llm_pipeline(
            kw["query"], kw["llm_endpoint"], kw["llm_api_key"],
            kw["llm_model"], kw["namespace"], kw["user_id"])
    return await query_pipeline(kw["query"], kw["namespace"], kw["top_k"],
                                kw["rerank"], kw["user_id"],
                                use_web=kw.get("use_web", False))
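

# Minimal manual smoke test -- a sketch, assuming GEMINI_API_KEY, JINA_API_KEY,
# and the Supabase credentials are configured and documents have been ingested.
if __name__ == "__main__":
    import asyncio

    result = asyncio.run(run_acra_pipeline(
        "query", query="What is adaptive chunking?",
        namespace="demo", top_k=5, rerank=True, user_id="u1"))
    print(result["answer"])
    print(result["cost"])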