File size: 16,750 Bytes
b0344af
ea894c2
 
7c073ac
 
 
 
92d9d21
9d77494
df5b006
9d77494
7c073ac
fd390a0
 
 
 
b0344af
 
 
 
fd390a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0344af
fd390a0
b0344af
 
 
fd390a0
 
 
b0344af
 
fd390a0
 
 
 
 
b0344af
 
 
fd390a0
b0344af
 
fd390a0
 
b0344af
 
 
fd390a0
 
 
b0344af
 
7c073ac
92d9d21
9d77494
92d9d21
7c073ac
 
92d9d21
9d77494
92d9d21
7c073ac
 
 
 
 
 
 
 
 
 
 
 
 
4c3485d
fd390a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea894c2
7c073ac
b0344af
ea894c2
7c073ac
 
 
23f097c
 
b0344af
23f097c
92a11ef
 
b0344af
7c073ac
4c3485d
 
 
23f097c
7c073ac
 
 
ea894c2
 
7c073ac
ea894c2
7c073ac
 
 
92a11ef
ea894c2
 
 
 
4c3485d
 
92a11ef
4c3485d
 
7c073ac
 
 
 
 
92d9d21
 
4c3485d
 
7c073ac
 
 
 
 
ea894c2
b0344af
 
 
9d77494
 
 
 
fd390a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c3485d
92a11ef
 
fd390a0
 
 
 
 
4c3485d
92a11ef
b0344af
fd390a0
 
4c3485d
fd390a0
4c3485d
92a11ef
 
fd390a0
 
 
92a11ef
fd390a0
 
 
 
 
b0344af
4c3485d
fd390a0
9d77494
fd390a0
7c073ac
 
 
 
 
fd390a0
 
7c073ac
 
9d77494
fd390a0
4c3485d
fd88a61
fd390a0
 
 
 
 
 
 
 
 
4c3485d
 
fd390a0
 
4c3485d
b0344af
fd390a0
 
7c073ac
cfb2fb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c073ac
 
4c3485d
 
cfb2fb8
 
 
 
92d9d21
4c3485d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
import os, httpx, time, re
from google import genai
from google.genai import types
from db import supabase
from classifier_inference import classify_query
from typing import List

# Shared Gemini client, built at import time. Importing this module raises
# KeyError when GEMINI_API_KEY is not set in the environment.
client      = genai.Client(api_key=os.environ["GEMINI_API_KEY"])
# Model used by embed_texts / embed_query.
EMBED_MODEL = "gemini-embedding-001"
# Model used by _generate for every text-generation call.
GEN_MODEL   = "gemini-3.1-flash-lite-preview"
# Classifier level -> number of chunks (k) that query_pipeline retrieves.
DEPTH       = {0: 3, 1: 3, 2: 6, 3: 10}

# Per-token rates consumed by calc_cost. Each is written as a per-million
# figure divided by 1_000_000; calc_cost reports the results under *_usd keys.
# NOTE(review): presumably list prices for the Gemini model above and for
# GPT-4o (the comparison baseline) -- confirm against current price sheets.
PRICE_IN  = 0.075  / 1_000_000
PRICE_OUT = 0.30   / 1_000_000
GPT4O_IN  = 2.50   / 1_000_000
GPT4O_OUT = 10.00  / 1_000_000

# Running token totals: incremented by _generate and reset to zero at the
# start of query_pipeline and byo_llm_pipeline. They are plain module
# globals, so every caller in the process shares them.
_total_input_tokens  = 0
_total_output_tokens = 0

def _get_tokens(usage_metadata):
    """Safely extract tokens β€” field names differ across SDK versions."""
    if not usage_metadata:
        return 0, 0
    in_tok = (
        getattr(usage_metadata, "prompt_token_count", None) or
        getattr(usage_metadata, "input_token_count", None) or
        getattr(usage_metadata, "total_token_count", None) or 0
    )
    out_tok = (
        getattr(usage_metadata, "candidates_token_count", None) or
        getattr(usage_metadata, "output_token_count", None) or 0
    )
    return in_tok, out_tok

def _generate(contents, retries=4):
    """Call the Gemini generation model, retrying transient failures.

    An exception is treated as transient when its string form contains one
    of "503", "429", "UNAVAILABLE" or "Resource"; any other exception is
    re-raised immediately. Token usage of a successful call is added to the
    module-level counters.

    Args:
        contents: Prompt passed straight through to ``generate_content``.
        retries: Maximum number of attempts.

    Returns:
        The response object returned by ``client.models.generate_content``.

    Raises:
        RuntimeError: Every attempt failed with a transient error (the last
            error is attached as the cause).
    """
    global _total_input_tokens, _total_output_tokens
    last_err = None
    for attempt in range(retries):
        try:
            r = client.models.generate_content(model=GEN_MODEL, contents=contents)
        except Exception as e:
            last_err = e
            err_str  = str(e)
            if not any(code in err_str for code in ["503", "429", "UNAVAILABLE", "Resource"]):
                raise
            if attempt == retries - 1:
                # No attempt follows this one, so backing off would only
                # delay the error by the longest wait.
                break
            wait = 2 ** attempt  # 1, 2, 4, ... seconds
            # NOTE(review): the "β€”" below looks like a mis-decoded dash; it is
            # runtime output and therefore left exactly as found.
            print(f"Gemini {err_str[:40]} β€” retry {attempt+1}/{retries} in {wait}s")
            time.sleep(wait)
        else:
            in_tok, out_tok = _get_tokens(getattr(r, "usage_metadata", None))
            _total_input_tokens  += in_tok
            _total_output_tokens += out_tok
            return r
    raise RuntimeError(f"Gemini unavailable after {retries} retries: {last_err}") from last_err

def calc_cost(in_tok, out_tok):
    """Price a token usage at this pipeline's rates and at the GPT-4o rates.

    Args:
        in_tok: Number of input tokens.
        out_tok: Number of output tokens.

    Returns:
        Dict with the two token counts, both costs rounded to 6 decimals,
        and ``savings_pct`` (percentage saved relative to the GPT-4o cost,
        rounded to 1 decimal; 0 when the GPT-4o cost is zero).
    """
    own_cost = in_tok * PRICE_IN + out_tok * PRICE_OUT
    baseline_cost = in_tok * GPT4O_IN + out_tok * GPT4O_OUT

    # Guard the division: zero usage means a zero baseline.
    if baseline_cost:
        savings = round((1 - own_cost / baseline_cost) * 100, 1)
    else:
        savings = 0

    return dict(
        input_tokens=in_tok,
        output_tokens=out_tok,
        acra_cost_usd=round(own_cost, 6),
        gpt4o_cost_usd=round(baseline_cost, 6),
        savings_pct=savings,
    )

def embed_texts(texts):
    """Embed a batch of document texts with the embedding model.

    Args:
        texts: Sequence of strings to embed.

    Returns:
        One embedding (the ``values`` of each returned embedding) per input
        text, in the order the API returned them. An empty input yields an
        empty list without contacting the API.
    """
    if not texts:
        # Nothing to embed: skip the network round-trip entirely.
        return []
    r = client.models.embed_content(model=EMBED_MODEL, contents=texts,
        config=types.EmbedContentConfig(task_type="RETRIEVAL_DOCUMENT", output_dimensionality=768))
    return [e.values for e in r.embeddings]

def embed_query(q):
    """Embed a single search query and return its vector.

    Uses the ``RETRIEVAL_QUERY`` task type with the same 768-dimension
    output that ``embed_texts`` requests for documents.
    """
    query_config = types.EmbedContentConfig(
        task_type="RETRIEVAL_QUERY",
        output_dimensionality=768,
    )
    response = client.models.embed_content(
        model=EMBED_MODEL,
        contents=[q],
        config=query_config,
    )
    # A one-element batch was sent, so only the first embedding is read.
    first_embedding = response.embeddings[0]
    return first_embedding.values

def adaptive_chunk(text, max_tok=512):
    """Greedily pack blank-line-separated paragraphs into chunks.

    Token usage is estimated as ``word_count / 0.75``. A paragraph joins the
    current chunk while the combined estimate stays below ``max_tok``;
    otherwise the current chunk is closed and the paragraph starts a new
    one. A single paragraph is never split, so it can exceed the budget on
    its own.

    Args:
        text: Source text.
        max_tok: Estimated-token budget per chunk.

    Returns:
        List of chunk strings; ``[text]`` when no non-blank paragraph exists.
    """
    paragraphs = []
    for raw in text.split("\n\n"):
        trimmed = raw.strip()
        if trimmed:
            paragraphs.append(trimmed)

    pieces = []
    buffer = ""
    for para in paragraphs:
        estimate = (len(buffer.split()) + len(para.split())) / 0.75
        if estimate < max_tok:
            # Still within budget: extend the open chunk.
            buffer = (buffer + "\n\n" + para).strip()
            continue
        # Budget exceeded: close the open chunk (if any) and start afresh.
        if buffer:
            pieces.append(buffer)
        buffer = para

    if buffer:
        pieces.append(buffer)
    return pieces if pieces else [text]

def web_search(query: str, max_results: int = 5) -> List[dict]:
    """Search the web through Jina, retrying once with a shortened query.

    The full query is tried first. If that yields nothing usable and the
    query has more than eight words, a second attempt is made with only the
    first eight words. Failures are logged with ``print`` and never raised:
    the function is best-effort by design.

    Args:
        query: Search text.
        max_results: Maximum number of result items to inspect per attempt.

    Returns:
        List of ``{"title", "snippet", "url"}`` dicts (snippet capped at 600
        characters); empty when the query is blank or every attempt failed.
    """
    import urllib.parse

    if not query or not query.strip():
        # A blank query cannot produce results; avoid the network calls.
        return []

    jina_key = os.environ.get("JINA_API_KEY", "")
    words = query.split()
    queries_to_try = [query]
    if len(words) > 8:
        # Only add the simplified form when it actually differs; otherwise
        # the fallback would just repeat the request that already failed.
        queries_to_try.append(" ".join(words[:8]))

    for attempt_q in queries_to_try:
        try:
            encoded = urllib.parse.quote(attempt_q)
            r = httpx.get(
                f"https://s.jina.ai/?q={encoded}",
                headers={
                    "Authorization": f"Bearer {jina_key}",
                    "Accept": "application/json",
                    "X-Retain-Images": "none",
                    "X-Engine": "direct",
                },
                timeout=25.0,
                follow_redirects=True
            )
            if r.status_code != 200:
                print(f"Jina {r.status_code} on attempt query: {attempt_q[:60]}")
                continue
            # "data" may be absent or null; treat both as "no results".
            items = r.json().get("data") or []
            out = []
            for item in items[:max_results]:
                snippet = item.get("description") or item.get("content", "")
                if snippet:
                    out.append({
                        "title":   item.get("title", ""),
                        "snippet": snippet[:600],
                        "url":     item.get("url", "")
                    })
            if out:
                return out
        except Exception as e:
            # Deliberately broad: network, JSON and shape errors all mean
            # "this attempt produced nothing", so move on to the next one.
            print(f"Web search error: {e}")
            continue
    return []

def decompose(query):
    """Ask the generation model to split ``query`` into simpler sub-queries.

    The reply is expected to be a numbered list. Only the list marker
    ("1." / "1)") is removed from each line, so a sub-query that itself
    begins with a number (e.g. "2024 revenue ...") keeps its leading digits.

    Args:
        query: The original user query.

    Returns:
        Up to four non-empty sub-queries, or ``[query]`` when the reply
        contained none.
    """
    r = _generate(f"Decompose into 2-4 simpler sub-queries. Numbered list only.\n\nQuery: {query}")
    marker = re.compile(r"^\s*\d+\s*[.)]\s*")
    lines = []
    # r.text can be None when the response carries no text parts.
    for raw in (r.text or "").splitlines():
        cleaned = marker.sub("", raw, count=1).strip()
        if cleaned:
            lines.append(cleaned)
    return lines[:4] or [query]

def compress(query, chunks):
    """Reduce each chunk to the sentences relevant to ``query`` via the model.

    The model is asked to answer as ``[1] <text or EMPTY> [2] ...``. When
    parsing the reply, a bracketed number only counts as a chunk marker if
    it is larger than the previously accepted marker and no larger than
    ``len(chunks)``; other bracketed numbers (for example citations such as
    "[12]" inside an extract) stay part of the surrounding text.

    Args:
        query: The user query.
        chunks: List of chunk strings.

    Returns:
        The non-empty extracts in reply order; the original ``chunks`` when
        nothing usable could be parsed; ``[]`` when ``chunks`` is empty.
    """
    if not chunks:
        return []
    numbered = "\n\n".join(f"[{i+1}]\n{c}" for i, c in enumerate(chunks))
    r = _generate(
        f"You have {len(chunks)} text chunks and a query.\n"
        f"For each chunk extract ONLY sentences relevant to the query.\n"
        f"Reply as [1] <text or EMPTY> [2] <text or EMPTY> etc.\n\n"
        f"Query: {query}\n\nChunks:\n{numbered}")
    # r.text can be None when the response carries no text parts.
    reply = r.text or ""

    markers = []
    last_idx = 0
    for m in re.finditer(r"\[(\d+)\]", reply):
        idx = int(m.group(1))
        if last_idx < idx <= len(chunks):
            markers.append(m)
            last_idx = idx

    out = []
    for pos, m in enumerate(markers):
        # An extract runs from its marker up to the next accepted marker.
        end = markers[pos + 1].start() if pos + 1 < len(markers) else len(reply)
        t = reply[m.end():end].strip()
        if t and t.upper() != "EMPTY":
            out.append(t)
    return out if out else chunks

def vsearch(query, namespace, user_id, k):
    """Run the ``match_documents`` RPC for ``query`` and return its rows.

    Args:
        query: Text to embed and search with.
        namespace: Value passed as ``filter_namespace``.
        user_id: Value passed as ``filter_user_id``.
        k: Value passed as ``match_count``.

    Returns:
        The ``data`` of the RPC response, or ``[]`` when it is empty/None.
    """
    rpc_args = {
        "query_embedding": embed_query(query),
        "match_count": k,
        "filter_namespace": namespace,
        "filter_user_id": user_id,
    }
    response = supabase.rpc("match_documents", rpc_args).execute()
    rows = response.data
    return rows if rows else []

# Answer templates keyed by classifier level; filled with str.format using
# the fields {ctx} (retrieved context) and {q} (the question). Level 0 has
# no {ctx} field.
PROMPTS = {
    0: (
        "Answer from your knowledge:\n\n"
        "{q}"
    ),
    1: (
        "Answer using ONLY the context. Be concise.\n\n"
        "Context:\n{ctx}\n\n"
        "Question: {q}\nAnswer:"
    ),
    2: (
        "Synthesize the context step by step.\n\n"
        "Context:\n{ctx}\n\n"
        "Question: {q}\nAnswer:"
    ),
    3: (
        "Use chain-of-thought reasoning.\n\n"
        "Context:\n{ctx}\n\n"
        "Question: {q}\nAnswer:"
    ),
}

# Template used whenever the context comes from web search results.
WEB_PROMPT = "\n".join([
    "Answer the question using ONLY the web search results below.",
    "Be factual and concise.",
    "",
    "Web results:",
    "{ctx}",
    "",
    "Question: {q}",
    "Answer:",
])

async def ingest_pipeline(texts, metadata, namespace, user_id):
    """Chunk, embed and store documents in the ``documents`` table.

    Each text is split with ``adaptive_chunk``; every chunk row carries the
    source's metadata plus ``source_index`` and ``chunk_index``. Rows are
    inserted in batches of 50.

    Args:
        texts: Document strings (``None`` is treated as no documents).
        metadata: Per-document metadata dicts, positionally matching
            ``texts``. Missing, short or ``None`` entries fall back to an
            empty dict instead of raising.
        namespace: Stored on every row.
        user_id: Stored on every row.

    Returns:
        Number of rows built for insertion (0 when there was nothing to
        ingest, in which case neither the embedding API nor the database is
        contacted).
    """
    chunks, meta = [], []
    for i, t in enumerate(texts or []):
        base = metadata[i] if metadata and i < len(metadata) else None
        for j, c in enumerate(adaptive_chunk(t)):
            if not c.strip():
                # adaptive_chunk echoes a blank document back as one blank
                # chunk; a blank chunk has nothing worth embedding.
                continue
            chunks.append(c)
            meta.append({**(base or {}), "source_index": i, "chunk_index": j})

    if not chunks:
        return 0

    rows = [{"content": c, "embedding": e, "metadata": m,
             "namespace": namespace, "user_id": user_id}
            for c, e, m in zip(chunks, embed_texts(chunks), meta)]
    for start in range(0, len(rows), 50):
        supabase.table("documents").insert(rows[start:start+50]).execute()
    return len(rows)

async def query_pipeline(query, namespace, top_k, rerank, user_id, use_web=False):
    """Answer ``query`` with retrieval depth chosen by the query classifier.

    Flow:
      * ``use_web=True``: web search only, falling back to model knowledge.
      * level 0: a 2-hit local lookup offered as optional context, else
        model knowledge.
      * levels 1-3: local vector search (level 3 searches each decomposed
        sub-query), falling back to web search and then to model knowledge
        when nothing is found locally.

    Args:
        query: User question.
        namespace: Namespace filter for the vector search.
        top_k: NOTE(review): accepted but never read here; the retrieval
            depth comes from ``DEPTH[level]`` instead. Confirm whether that
            is intended.
        rerank: When truthy and level >= 2, chunks are passed through
            ``compress`` before being used as context.
        user_id: User filter for the vector search.
        use_web: Skip local retrieval and use web search.

    Returns:
        Dict with ``answer``, ``sources``, ``complexity`` (the classifier
        result), ``retrieval_source`` and ``cost``.
    """
    global _total_input_tokens, _total_output_tokens
    # The counters are module globals shared by every caller; reset them so
    # the reported cost covers this call only.
    _total_input_tokens = _total_output_tokens = 0

    cls   = classify_query(query)
    level = cls["level"]
    k     = DEPTH[level]

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _answer(response):
        # response.text can be None when the model returned no text parts;
        # report an empty answer instead of raising AttributeError.
        return (response.text or "").strip()

    def _web_sources(hits):
        return [{"content": h["snippet"][:200],
                 "metadata": {"title": h["title"], "url": h["url"]},
                 "score": 1.0, "source": "web"} for h in hits]

    def _local_sources(hits):
        return [{"content": h["content"][:200],
                 "metadata": h.get("metadata", {}),
                 "score": h.get("similarity", 0),
                 "source": "local"} for h in hits]

    # -- use_web=True: web search only -------------------------------------
    if use_web:
        hits = web_search(query, max_results=6)
        if not hits:
            # Last resort: answer from model knowledge
            r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
            return {"answer": _answer(r), "sources": [],
                    "complexity": cls, "retrieval_source": "model_knowledge",
                    "cost": _cost()}
        ctx = "\n\n---\n\n".join(
            f"Source: {h['title']}\nURL: {h['url']}\n{h['snippet']}" for h in hits)
        r = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
        return {"answer": _answer(r), "sources": _web_sources(hits),
                "complexity": cls, "retrieval_source": "web", "cost": _cost()}

    # -- Level 0 -----------------------------------------------------------
    if level == 0:
        doc_hits = vsearch(query, namespace, user_id, 2)
        if doc_hits:
            ctx       = "\n\n---\n\n".join(h["content"] for h in doc_hits)
            r         = _generate(f"Use the context if relevant, else answer from knowledge.\n\n"
                                   f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:")
            top_score = doc_hits[0].get("similarity", 0)
            # Only hits scoring above 0.5 are reported as sources.
            return {"answer": _answer(r),
                    "sources": [s for s in _local_sources(doc_hits) if s["score"] > 0.5],
                    "complexity": cls,
                    "retrieval_source": "local" if top_score > 0.5 else "model_knowledge",
                    "cost": _cost()}
        r = _generate(f"Answer from your knowledge:\n\n{query}")
        return {"answer": _answer(r), "sources": [],
                "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}

    # -- Levels 1-3: local vector search -----------------------------------
    hits = []
    if level == 3:
        # Search every sub-query and keep each document id once.
        seen = set()
        for sq in decompose(query):
            for h in vsearch(sq, namespace, user_id, 4):
                if h["id"] not in seen:
                    seen.add(h["id"])
                    hits.append(h)
    else:
        hits = vsearch(query, namespace, user_id, k)

    # Fallback to web if no local docs
    if not hits:
        web_hits = web_search(query, max_results=k)
        if web_hits:
            ctx = "\n\n---\n\n".join(f"Source: {h['title']}\n{h['snippet']}" for h in web_hits)
            r   = _generate(WEB_PROMPT.format(ctx=ctx, q=query))
            return {"answer": _answer(r), "sources": _web_sources(web_hits),
                    "complexity": cls, "retrieval_source": "web", "cost": _cost()}
        # Final fallback: model knowledge
        r = _generate(f"Answer from your knowledge. Be thorough.\n\n{query}")
        return {"answer": _answer(r), "sources": [],
                "complexity": cls, "retrieval_source": "model_knowledge", "cost": _cost()}

    lc = [h["content"] for h in hits]
    if rerank and level >= 2:
        lc = [c for c in compress(query, lc) if c.strip()] or lc
    used = lc[:k]
    ctx = "\n\n---\n\n".join(used)
    r   = _generate(PROMPTS[level].format(ctx=ctx, q=query))
    # Report no more sources than chunks that actually reached the prompt
    # (level 3 can gather more hits than k).
    # NOTE(review): after compress() the pairing between hits and extracts is
    # positional only, because compress does not return chunk indices.
    return {"answer": _answer(r), "sources": _local_sources(hits[:len(used)]),
            "complexity": cls, "retrieval_source": "local", "cost": _cost()}


async def byo_llm_pipeline(query, llm_endpoint, llm_api_key, llm_model, namespace, user_id):
    """Answer ``query`` with a caller-supplied OpenAI-compatible LLM.

    Steps:
      1. Ask the caller's LLM which parts of the query it is unsure about.
      2. Run a web search for each reported gap (or for the query itself).
      3. Have Gemini condense the raw search snippets into short bullets.
      4. Ask the caller's LLM to answer using those bullets; if that call
         fails, Gemini answers instead.

    Args:
        query: User question.
        llm_endpoint: URL that receives the chat-completion POST requests.
            NOTE(review): it is used verbatim as the request target; if it
            originates from end users, validate it upstream.
        llm_api_key: Sent as a Bearer token to ``llm_endpoint``.
        llm_model: Model name placed in the request body.
        namespace: Not read by this pipeline.
        user_id: Not read by this pipeline.

    Returns:
        Dict with ``answer``, ``sources``, ``complexity`` (the classifier
        result), ``retrieval_source`` and ``cost``. The cost is computed from
        the Gemini token counters only.
    """
    global _total_input_tokens, _total_output_tokens
    # The counters are module globals; reset them so the reported cost
    # covers this call only.
    _total_input_tokens = _total_output_tokens = 0

    cls = classify_query(query)

    def _cost():
        return calc_cost(_total_input_tokens, _total_output_tokens)

    def _call_their_llm(messages):
        """Call their OpenAI-compatible endpoint and return the reply text."""
        r = httpx.post(
            llm_endpoint,
            headers={
                "Authorization": f"Bearer {llm_api_key}",
                "Content-Type": "application/json"
            },
            json={
                "model": llm_model,
                "messages": messages,
                "max_tokens": 300,
                "temperature": 0.1,
            },
            timeout=30.0
        )
        r.raise_for_status()
        return r.json()["choices"][0]["message"]["content"].strip()

    # -- Step 1: ask their LLM what it does not know -----------------------
    gap_prompt = (
        f"You will answer a user query. Before answering, identify ONLY what you are "
        f"uncertain or lack recent data about.\n"
        f"Reply with max 6 short lines like:\n"
        f"- I don't know: [specific gap]\n\n"
        f"Query: {query}\n\n"
        f"What are your knowledge gaps? 6 lines max, be specific."
    )
    try:
        gaps_text = _call_their_llm([{"role": "user", "content": gap_prompt}])
        print(f"Gaps detected: {gaps_text[:200]}")
    except Exception as e:
        print(f"Gap detection failed: {e} β€” falling back to full query search")
        gaps_text = query

    # -- Step 2: web search for each gap -----------------------------------
    gap_lines = []
    for raw in gaps_text.split("\n"):
        line = raw.strip()
        if len(line) <= 10:
            # Too short to be a meaningful gap description.
            continue
        gap = (line.lstrip("- ")
                   .replace("I don't know:", "")
                   .replace("I am unsure about:", "")
                   .strip())
        if gap:
            # A line consisting only of the boilerplate prefix leaves nothing
            # to search for, so it is dropped.
            gap_lines.append(gap)
    gap_lines = gap_lines[:6]

    all_hits = []
    seen_urls = set()
    for gap in gap_lines or [query]:
        hits = web_search(gap, max_results=2)
        for h in hits:
            if h["url"] not in seen_urls:
                seen_urls.add(h["url"])
                all_hits.append(h)

    if not all_hits:
        # No web results: send the query directly to their LLM.
        try:
            answer = _call_their_llm([{"role": "user", "content": query}])
        except Exception as e:
            answer = f"LLM call failed: {e}"
        return {"answer": answer, "sources": [], "complexity": cls,
                "retrieval_source": "model_knowledge", "cost": _cost()}

    # -- Step 3: Gemini condenses the raw web snippets ---------------------
    raw_ctx = "\n\n---\n\n".join(
        f"Source: {h['title']}\n{h['snippet']}" for h in all_hits)

    clean_prompt = (
        f"You are a data cleaner. Extract ONLY facts relevant to this query.\n"
        f"Format: bullet points, max 15 words per bullet, no fluff, no URLs.\n"
        f"Output max 10 bullets total.\n\n"
        f"Query: {query}\n\nRaw web data:\n{raw_ctx}\n\nClean bullets:"
    )
    clean_r  = _generate(clean_prompt)
    # .text can be None when the response carries no text parts.
    clean_ctx = (clean_r.text or "").strip()
    print(f"Cleaned context ({len(clean_ctx)} chars):\n{clean_ctx[:300]}")

    # -- Step 4: their LLM answers with the condensed context --------------
    final_messages = [
        {"role": "system", "content":
            "You are a helpful assistant. Use the provided context to answer accurately. "
            "If context doesn't help, use your own knowledge."},
        {"role": "user", "content":
            f"Context (verified web facts):\n{clean_ctx}\n\nQuestion: {query}\nAnswer:"}
    ]
    try:
        answer = _call_their_llm(final_messages)
    except Exception as e:
        # Fallback to Gemini if their LLM fails
        print(f"Their LLM failed: {e} β€” falling back to Gemini")
        r      = _generate(f"Context:\n{clean_ctx}\n\nQuestion: {query}\nAnswer:")
        answer = (r.text or "").strip()

    return {
        "answer":           answer,
        "sources":          [{"content": h["snippet"][:200],
                              "metadata": {"title": h["title"], "url": h["url"]},
                              "score": 1.0, "source": "web"} for h in all_hits],
        "complexity":       cls,
        "retrieval_source": "byo_llm+web",
        "cost":             _cost(),
    }

async def run_acra_pipeline(mode, **kw):
    """Dispatch to the ingest, bring-your-own-LLM or query pipeline.

    ``mode == "ingest"`` runs ``ingest_pipeline``. Otherwise a truthy
    ``llm_endpoint`` keyword selects ``byo_llm_pipeline``, and everything
    else goes to ``query_pipeline``. Required keywords are read with
    ``kw[...]`` and therefore raise ``KeyError`` when absent; ``use_web``
    is optional and defaults to False.
    """
    if mode == "ingest":
        ingest_args = [kw[name] for name in
                       ("texts", "metadata", "namespace", "user_id")]
        return await ingest_pipeline(*ingest_args)

    if kw.get("llm_endpoint"):
        byo_args = [kw[name] for name in
                    ("query", "llm_endpoint", "llm_api_key",
                     "llm_model", "namespace", "user_id")]
        return await byo_llm_pipeline(*byo_args)

    query_args = [kw[name] for name in
                  ("query", "namespace", "top_k", "rerank", "user_id")]
    web_flag = kw.get("use_web", False)
    return await query_pipeline(*query_args, use_web=web_flag)