import os, httpx, time, re
import urllib.parse
from google import genai
from google.genai import types
from db import supabase
from classifier_inference import classify_query
from typing import List

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

EMBED_MODEL = "gemini-embedding-001"
GEN_MODEL = "gemini-3.1-flash-lite-preview"

# Retrieval depth (chunks fetched) per query-complexity level.
DEPTH = {0: 3, 1: 3, 2: 6, 3: 10}

# USD per token: our model's pricing vs. GPT-4o reference pricing.
PRICE_IN = 0.075 / 1_000_000
PRICE_OUT = 0.30 / 1_000_000
GPT4O_IN = 2.50 / 1_000_000
GPT4O_OUT = 10.00 / 1_000_000

_total_input_tokens = 0
_total_output_tokens = 0


def _get_tokens(usage_metadata):
    """Safely extract tokens — field names differ across SDK versions."""
    if not usage_metadata:
        return 0, 0
    in_tok = (
        getattr(usage_metadata, "prompt_token_count", None)
        or getattr(usage_metadata, "input_token_count", None)
        or getattr(usage_metadata, "total_token_count", None)
        or 0
    )
    out_tok = (
        getattr(usage_metadata, "candidates_token_count", None)
        or getattr(usage_metadata, "output_token_count", None)
        or 0
    )
    return in_tok, out_tok


def _generate(contents, retries=4):
    """Gemini call with retry on 503/429 + robust token tracking."""
    global _total_input_tokens, _total_output_tokens
    last_err = None
    for attempt in range(retries):
        try:
            r = client.models.generate_content(model=GEN_MODEL, contents=contents)
            in_tok, out_tok = _get_tokens(getattr(r, "usage_metadata", None))
            _total_input_tokens += in_tok
            _total_output_tokens += out_tok
            return r
        except Exception as e:
            last_err = e
            err_str = str(e)
            if any(code in err_str for code in ["503", "429", "UNAVAILABLE", "Resource"]):
                wait = 2 ** attempt  # 1, 2, 4, 8s
                print(f"Gemini {err_str[:40]} — retry {attempt+1}/{retries} in {wait}s")
                time.sleep(wait)
            else:
                raise
    raise RuntimeError(f"Gemini unavailable after {retries} retries: {last_err}")


def calc_cost(in_tok, out_tok):
    acra = in_tok * PRICE_IN + out_tok * PRICE_OUT
    gpt4o = in_tok * GPT4O_IN + out_tok * GPT4O_OUT
    return {
        "input_tokens": in_tok,
        "output_tokens": out_tok,
        "acra_cost_usd": round(acra, 6),
        "gpt4o_cost_usd": round(gpt4o, 6),
        "savings_pct": round((1 - acra / gpt4o) * 100, 1) if gpt4o else 0,
    }


def embed_texts(texts):
    r = client.models.embed_content(
        model=EMBED_MODEL,
        contents=texts,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT",
                                        output_dimensionality=768),
    )
    return [e.values for e in r.embeddings]


def embed_query(q):
    r = client.models.embed_content(
        model=EMBED_MODEL,
        contents=[q],
        config=types.EmbedContentConfig(task_type="RETRIEVAL_QUERY",
                                        output_dimensionality=768),
    )
    return r.embeddings[0].values


def adaptive_chunk(text, max_tok=512):
    paras = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks, cur = [], ""
    for p in paras:
        if (len(cur.split()) + len(p.split())) / 0.75 < max_tok:
            cur = (cur + "\n\n" + p).strip()
        else:
            if cur:
                chunks.append(cur)
            cur = p
    if cur:
        chunks.append(cur)
    return chunks or [text]
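
# Illustrative example (a sketch, not part of the pipeline): two short
# paragraphs fit well under the ~512-token budget, so they merge into one
# chunk:
#
#     >>> adaptive_chunk("First paragraph.\n\nSecond paragraph.")
#     ['First paragraph.\n\nSecond paragraph.']
#
# Dividing the word count by 0.75 estimates tokens (~0.75 words per token
# for English prose); it is a heuristic, not a tokenizer, so chunks can
# overshoot max_tok on token-dense text.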

def web_search(query: str, max_results: int = 5) -> List[dict]:
    """Jina web search with automatic query simplification fallback."""
    jina_key = os.environ.get("JINA_API_KEY", "")
    queries_to_try = [query, " ".join(query.split()[:8])]  # full, then simplified
    for attempt_q in queries_to_try:
        try:
            encoded = urllib.parse.quote(attempt_q)
            r = httpx.get(
                f"https://s.jina.ai/?q={encoded}",
                headers={
                    "Authorization": f"Bearer {jina_key}",
                    "Accept": "application/json",
                    "X-Retain-Images": "none",
                    "X-Engine": "direct",
                },
                timeout=25.0,
                follow_redirects=True,
            )
            if r.status_code != 200:
                print(f"Jina {r.status_code} on attempt query: {attempt_q[:60]}")
                continue
            items = r.json().get("data", [])
            out = []
            for item in items[:max_results]:
                snippet = item.get("description") or item.get("content", "")
                if snippet:
                    out.append({
                        "title": item.get("title", ""),
                        "snippet": snippet[:600],
                        "url": item.get("url", ""),
                    })
            if out:
                return out
        except Exception as e:
            print(f"Web search error: {e}")
            continue
    return []


def decompose(query):
    r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
    lines = [l.strip().lstrip("1234567890.). ") for l in r.text.strip().split("\n") if l.strip()]
    return lines[:4] or [query]


def compress(query, chunks):
    if not chunks:
        return []
    numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
    r = _generate(
        f"You have {len(chunks)} text chunks and a query.\n"
        f"For each chunk extract ONLY sentences relevant to the query.\n"
        f"Reply as [1] [2] etc. Write EMPTY for any chunk with nothing relevant.\n\n"
        f"Query: {query}\n\nChunks:\n{numbered}")
    out = []
    for m in re.finditer(r"\[(\d+)\]\s*(.*?)(?=\[\d+\]|$)", r.text, re.DOTALL):
        t = m.group(2).strip()
        if t and t.upper() != "EMPTY":
            out.append(t)
    return out if out else chunks


def vsearch(query, namespace, user_id, k):
    return (supabase.rpc("match_documents", {
        "query_embedding": embed_query(query),
        "match_count": k,
        "filter_namespace": namespace,
        "filter_user_id": user_id,
    }).execute().data or [])


PROMPTS = {
    0: "Answer from your knowledge:\n\n{q}",
    1: "Answer using ONLY the context. Be concise.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
    2: "Synthesize the context step by step.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
    3: "Use chain-of-thought reasoning.\n\nContext:\n{ctx}\n\nQuestion: {q}\nAnswer:",
}

WEB_PROMPT = (
    "Answer the question using ONLY the web search results below.\n"
    "Be factual and concise.\n\n"
    "Web results:\n{ctx}\n\nQuestion: {q}\nAnswer:"
)


async def ingest_pipeline(texts, metadata, namespace, user_id):
    chunks, meta = [], []
    for i, t in enumerate(texts):
        for j, c in enumerate(adaptive_chunk(t)):
            chunks.append(c)
            meta.append({**metadata[i], "source_index": i, "chunk_index": j})
    rows = [{"content": c, "embedding": e, "metadata": m,
             "namespace": namespace, "user_id": user_id}
            for c, e, m in zip(chunks, embed_texts(chunks), meta)]
    for i in range(0, len(rows), 50):
        supabase.table("documents").insert(rows[i:i+50]).execute()
    return len(chunks)
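
# Usage sketch (hedged assumptions: the Supabase "documents" table has a
# 768-dim pgvector "embedding" column matching output_dimensionality above,
# and the "match_documents" RPC used by vsearch() exists server-side; the
# file name is a placeholder):
#
#     n_chunks = await ingest_pipeline(
#         texts=[open("manual.txt").read()],
#         metadata=[{"filename": "manual.txt"}],
#         namespace="default",
#         user_id="user-123",
#     )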

async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
    global _total_input_tokens, _total_output_tokens
    _total_input_tokens = _total_output_tokens = 0
    cls = classify_query(query)
    level = cls["level"]
    k = DEPTH[level]  # depth is adaptive; the top_k argument is kept for API compatibility

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _web_sources(hits):
        return [{"content": h["snippet"][:200],
                 "metadata": {"title": h["title"], "url": h["url"]},
                 "score": 1.0, "source": "web"} for h in hits]

    def _local_sources(hits):
        return [{"content": h["content"][:200],
                 "metadata": h.get("metadata", {}),
                 "score": h.get("similarity", 0), "source": "local"} for h in hits]

    # ── use_web=True: pure Jina search ──────────────────────────
    if use_web:
        hits = web_search(query, max_results=6)
        if not hits:
            # Last resort: answer from model knowledge
            r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
            return {"answer": r.text.strip(), "sources": [], "complexity": cls,
                    "retrieval_source": "model_knowledge", "cost": _cost()}
        ctx = "\n\n---\n\n".join(
            f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
        r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
        return {"answer": r.text.strip(), "sources": _web_sources(hits),
                "complexity": cls, "retrieval_source": "web", "cost": _cost()}

    # ── L0 ───────────────────────────────────────────────────────
    if level == 0:
        doc_hits = vsearch(query, namespace, user_id, 2)
        if doc_hits:
            ctx = "\n\n---\n\n".join(h["content"] for h in doc_hits)
            r = _generate(f"Use the context if relevant, else answer from knowledge.\n\n"
                          f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
            top_score = doc_hits[0].get("similarity", 0)
            return {"answer": r.text.strip(),
                    "sources": [s for s in _local_sources(doc_hits) if s["score"] > 0.5],
                    "complexity": cls,
                    "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
                    "cost": _cost()}
        r = _generate(f"Answer from your knowledge:\n\n{query}")
        return {"answer": r.text.strip(), "sources": [], "complexity": cls,
                "retrieval_source": "model_knowledge", "cost": _cost()}

    # ── L1-L3: local vector search ───────────────────────────────
    hits = []
    if level == 3:
        seen = set()
        for sq in decompose(query):
            for h in vsearch(sq, namespace, user_id, 4):
                if h["id"] not in seen:
                    seen.add(h["id"])
                    hits.append(h)
    else:
        hits = vsearch(query, namespace, user_id, k)

    # Fallback to web if no local docs
    if not hits:
        web_hits = web_search(query, max_results=k)
        if web_hits:
            ctx = "\n\n---\n\n".join(f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
            r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
            return {"answer": r.text.strip(), "sources": _web_sources(web_hits),
                    "complexity": cls, "retrieval_source": "web", "cost": _cost()}
        # Final fallback: model knowledge
        r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
        return {"answer": r.text.strip(), "sources": [], "complexity": cls,
                "retrieval_source": "model_knowledge", "cost": _cost()}

    lc = [h["content"] for h in hits]
    if rerank and level >= 2:
        lc = [c for c in compress(query, lc) if c.strip()] or lc
    ctx = "\n\n---\n\n".join(lc[:k])
    r = _generate(PROMPTS[level].format(ctx=ctx, q=query))
    return {"answer": r.text.strip(), "sources": _local_sources(hits[:len(lc)]),
            "complexity": cls, "retrieval_source": "local", "cost": _cost()}
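
# Response envelope (a summary note, not a type enforced by the code):
# query_pipeline above and byo_llm_pipeline below both return a dict of
# this shape:
#
#     {
#         "answer": str,
#         "sources": list[dict],    # content / metadata / score / source
#         "complexity": dict,       # classify_query output, includes "level"
#         "retrieval_source": str,  # "local", "web", "model_knowledge", or "byo_llm+web"
#         "cost": dict,             # calc_cost output, incl. savings_pct
#     }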

async def byo_llm_pipeline(query, llm_endpoint, llm_api_key, llm_model, namespace, user_id):
    """
    BYO-LLM pipeline:
    1. Ask their LLM what it doesn't know (gap detection)
    2. Jina fetches exactly those gaps
    3. Gemini cleans raw web mess → clean bullet points
    4. Their LLM answers with tiny clean context
    """
    global _total_input_tokens, _total_output_tokens
    _total_input_tokens = _total_output_tokens = 0
    cls = classify_query(query)

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _call_their_llm(messages):
        """Call their OpenAI-compatible endpoint."""
        r = httpx.post(
            llm_endpoint,
            headers={
                "Authorization": f"Bearer {llm_api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": llm_model,
                "messages": messages,
                "max_tokens": 300,
                "temperature": 0.1,
            },
            timeout=30.0,
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"].strip()

    # ── Step 1: Ask their LLM what it doesn't know ──────────────
    gap_prompt = (
        f"You will answer a user query. Before answering, identify ONLY what you are "
        f"uncertain or lack recent data about.\n"
        f"Reply with max 6 short lines like:\n"
        f"- I don't know: [specific gap]\n\n"
        f"Query: {query}\n\n"
        f"What are your knowledge gaps? 6 lines max, be specific."
    )
    try:
        gaps_text = _call_their_llm([{"role": "user", "content": gap_prompt}])
        print(f"Gaps detected: {gaps_text[:200]}")
    except Exception as e:
        print(f"Gap detection failed: {e} — falling back to full query search")
        gaps_text = query

    # ── Step 2: Jina fetches exactly those gaps ──────────────────
    # Extract gap lines and search each one
    gap_lines = [l.strip().lstrip("- ")
                  .replace("I don't know:", "")
                  .replace("I am unsure about:", "").strip()
                 for l in gaps_text.split("\n")
                 if l.strip() and len(l.strip()) > 10][:6]

    all_hits = []
    seen_urls = set()
    for gap in gap_lines or [query]:
        hits = web_search(gap, max_results=2)
        for h in hits:
            if h["url"] not in seen_urls:
                seen_urls.add(h["url"])
                all_hits.append(h)

    if not all_hits:
        # No web results — just send query directly to their LLM
        try:
            answer = _call_their_llm([{"role": "user", "content": query}])
        except Exception as e:
            answer = f"LLM call failed: {e}"
        return {"answer": answer, "sources": [], "complexity": cls,
                "retrieval_source": "model_knowledge", "cost": _cost()}

    # ── Step 3: Gemini cleans raw web mess ───────────────────────
    raw_ctx = "\n\n---\n\n".join(
        f"Source: {h['title']}\n{h['snippet']}" for h in all_hits)
    clean_prompt = (
        f"You are a data cleaner. Extract ONLY facts relevant to this query.\n"
        f"Format: bullet points, max 15 words per bullet, no fluff, no URLs.\n"
        f"Output max 10 bullets total.\n\n"
        f"Query: {query}\n\nRaw web data:\n{raw_ctx}\n\nClean bullets:"
    )
    clean_r = _generate(clean_prompt)
    clean_ctx = clean_r.text.strip()
    print(f"Cleaned context ({len(clean_ctx)} chars):\n{clean_ctx[:300]}")

    # ── Step 4: Their LLM answers with clean context ─────────────
    final_messages = [
        {"role": "system", "content":
            "You are a helpful assistant. Use the provided context to answer accurately. "
            "If context doesn't help, use your own knowledge."},
        {"role": "user", "content":
            f"Context (verified web facts):\n{clean_ctx}\n\nQuestion: {query}\nAnswer:"},
    ]
    try:
        answer = _call_their_llm(final_messages)
    except Exception as e:
        # Fallback to Gemini if their LLM fails
        print(f"Their LLM failed: {e} — falling back to Gemini")
        r = _generate(f"Context:\n{clean_ctx}\n\nQuestion: {query}\nAnswer:")
        answer = r.text.strip()

    return {
        "answer": answer,
        "sources": [{"content": h["snippet"][:200],
                     "metadata": {"title": h["title"], "url": h["url"]},
                     "score": 1.0, "source": "web"} for h in all_hits],
        "complexity": cls,
        "retrieval_source": "byo_llm+web",
        "cost": _cost(),
    }


async def run_acra_pipeline(mode, **kw):
    if mode == "ingest":
        return await ingest_pipeline(kw["texts"], kw["metadata"],
                                     kw["namespace"], kw["user_id"])
    if kw.get("llm_endpoint"):
        return await byo_llm_pipeline(
            kw["query"], kw["llm_endpoint"], kw["llm_api_key"],
            kw["llm_model"], kw["namespace"], kw["user_id"])
    return await query_pipeline(kw["query"], kw["namespace"], kw["top_k"],
                                kw["rerank"], kw["user_id"],
                                use_web=kw.get("use_web", False))
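
# Minimal smoke-test sketch (hedged: requires GEMINI_API_KEY, JINA_API_KEY,
# and a reachable Supabase backend; the query, namespace, and user_id are
# placeholders):
if __name__ == "__main__":
    import asyncio

    async def _demo():
        result = await run_acra_pipeline(
            mode="query",
            query="What is adaptive chunking?",
            namespace="default",
            top_k=5,
            rerank=True,
            user_id="demo-user",
            use_web=False,
        )
        print(result["retrieval_source"])
        print(result["answer"])
        print(result["cost"])

    asyncio.run(_demo())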