File size: 6,674 Bytes
e2ffd2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
import os, pickle, json, gradio as gr
import numpy as np, faiss
from sentence_transformers import SentenceTransformer

# ---------- Optional LLM (OpenAI) ----------
# SECURITY FIX: the key was previously hard-coded in source. A committed key
# is compromised and must be revoked/rotated. Read it from the environment
# instead (add in Space -> Settings -> Secrets as OPENAI_API_KEY).
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
USE_OPENAI = bool(OPENAI_API_KEY)
if USE_OPENAI:
    try:
        from openai import OpenAI
        oai = OpenAI(api_key=OPENAI_API_KEY)
        OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
    except Exception as e:
        # openai package missing or client init failed -> retrieval-only mode
        print("[RAG] OpenAI not available:", e)
        USE_OPENAI = False

# ---------- Artifacts you already have ----------
# Paths are overridable via env so the same code runs locally and on a Space.
FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
META_PATH  = os.getenv("META_PATH",  "squad_v2_meta.pkl")

# Process-wide singleton holding the loaded artifacts; populated exactly once
# by load_artifacts() and read by _retrieve().
CACHE = {
    "index": None,       # faiss index object
    "contexts": None,    # list[str] aligned with index ids
    "encoder": None,     # SentenceTransformer instance
    "model_name": None   # encoder model name (for logging)
}

def _coerce_text_list(x):
    """Accepts list[str] or list[dict]; extracts text nicely."""
    out = []
    if isinstance(x, list):
        for it in x:
            if isinstance(it, str):
                out.append(it)
            elif isinstance(it, dict):
                # common keys people use
                text = it.get("text") or it.get("content") or it.get("ctx") or ""
                if text:
                    out.append(text)
    return out

def load_artifacts():
    """Load the FAISS index, context texts and sentence encoder into CACHE.

    Idempotent: returns immediately if the index is already cached.

    Raises:
        FileNotFoundError: if either artifact file is missing.
        ValueError: if the meta file yields no context texts.
    """
    if CACHE["index"] is not None:
        return

    # 1) FAISS index from disk
    if not os.path.exists(FAISS_PATH):
        raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
    index = faiss.read_index(FAISS_PATH)

    # 2) metadata pickle
    # NOTE(review): pickle deserializes arbitrary objects — only ever load
    # trusted artifact files here.
    if not os.path.exists(META_PATH):
        raise FileNotFoundError(f"Missing meta file: {META_PATH}")
    with open(META_PATH, "rb") as f:
        meta = pickle.load(f)

    # Meta may be a dict carrying the encoder name plus texts under one of
    # several common keys, or simply a bare list of contexts.
    model_name = "all-MiniLM-L6-v2"
    raw_contexts = None
    if isinstance(meta, dict):
        model_name = meta.get("model") or meta.get("encoder") or model_name
        for key in ("contexts", "texts", "documents", "corpus"):
            raw_contexts = meta.get(key)
            if raw_contexts:
                break
    else:
        raw_contexts = meta

    # Normalize whatever we found into list[str].
    contexts = _coerce_text_list(raw_contexts) if raw_contexts is not None else []
    if not contexts:
        raise ValueError("No contexts found in meta; expected a list of texts.")

    # Safeguard: if the index and text list disagree in size, trim the texts
    # so retrieval can never index out of range (FAISS cannot be resized here).
    ntotal = index.ntotal
    if ntotal != len(contexts):
        m = min(ntotal, len(contexts))
        print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
        contexts = contexts[:m]

    # 3) instantiate the encoder now so the first query pays no cold-start cost
    encoder = SentenceTransformer(model_name)

    CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
    print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
          f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")

def _retrieve(question: str, k: int):
    """Return up to *k* (context_id, distance, text) triples for *question*.

    Hits whose id falls outside the (possibly trimmed) context list are
    silently dropped, so fewer than *k* triples may come back.
    """
    # FAISS expects float32 query vectors.
    query_vec = CACHE["encoder"].encode([question]).astype("float32")
    distances, ids = CACHE["index"].search(query_vec, int(k))
    limit = len(CACHE["contexts"])
    # Guard against out-of-range ids caused by index/context size mismatch.
    return [
        (i, dist, CACHE["contexts"][i])
        for i, dist in zip(ids[0].tolist(), distances[0].tolist())
        if 0 <= i < limit
    ]

def _build_prompt(question: str, pairs):
    chunks = []
    for i, (_, _d, ctx) in enumerate(pairs, start=1):
        # keep prompt size reasonable
        ctx_short = ctx.strip()
        if len(ctx_short) > 1200:
            ctx_short = ctx_short[:1200] + "..."
        chunks.append(f"[Source {i}] {ctx_short}")
    context_block = "\n\n".join(chunks) if chunks else "(no context)"
    prompt = f"""Answer strictly from the context below. If not answerable, say so.

Include [Source X] citations in your answer.



Context:

{context_block}



Question: {question}

Answer:"""
    return prompt

def answer(question: str, k: int):
    """Answer *question* from the top-*k* retrieved contexts.

    Returns a (answer_text, citations, raw_json_string) triple; the last two
    are empty/None when the question is blank or retrieval finds nothing.
    When no OpenAI key is configured, the best-matching context is shown
    in place of a generated answer.
    """
    if not question.strip():
        return "Please enter a question.", [], None

    pairs = _retrieve(question, k)
    if not pairs:
        return "No results found in the index.", [], None

    # Citation records for the UI table: rank, distance, short snippet.
    citations = []
    for rank, (_id, dist, ctx) in enumerate(pairs, start=1):
        snippet = ctx[:240] + ("..." if len(ctx) > 240 else "")
        citations.append({"rank": rank, "faiss_dist": round(dist, 4), "snippet": snippet})

    if USE_OPENAI:
        prompt = _build_prompt(question, pairs)
        try:
            resp = oai.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2
            )
            ans = resp.choices[0].message.content
        except Exception as e:
            # Surface the failure but still let the UI show retrieved contexts.
            ans = f"LLM call failed: {e}\n\nTop results shown below."
    else:
        # Fallback: no key configured, show the top-1 context instead.
        ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" +
               pairs[0][2][:1200])

    # Simple JSON payload for debugging/export in the UI.
    raw = {"k": int(k), "answer": ans, "citations": citations}
    return ans, citations, json.dumps(raw, indent=2)

# ---------- UI ----------
# Gradio front-end: question box + top-k slider -> answer markdown, a
# citation table, and a raw-JSON debug panel. Artifacts load on page render.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)")

    with gr.Row():
        q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?")
        k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)")

    run_btn = gr.Button("Ask")
    ans_md  = gr.Markdown(label="Answer")
    # NOTE(review): answer() returns citations as list[dict] while this
    # Dataframe is declared with positional headers — confirm Gradio renders
    # record-style rows here as intended.
    cites   = gr.Dataframe(headers=["rank","faiss_dist","snippet"], datatype=["number","number","str"],
                           row_count=(0,"dynamic"), label="Retrieved contexts")
    raw_json = gr.JSON(label="Debug / raw response")

    def _startup():
        # Load FAISS/meta/encoder when the page first renders so the first
        # button click does not pay the cold-start cost.
        load_artifacts()
        return "Ready."

    status = gr.Markdown()
    demo.load(_startup, inputs=None, outputs=status)
    run_btn.click(answer, [q, k], [ans_md, cites, raw_json])

if __name__ == "__main__":
    # Warm the artifacts eagerly (fail fast on missing files), then serve.
    load_artifacts()
    demo.launch()