import os, pickle, json, gradio as gr
import numpy as np, faiss
from sentence_transformers import SentenceTransformer

# ---------- Optional LLM (OpenAI) ----------
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")  # add in Space -> Settings -> Secrets
USE_OPENAI = bool(OPENAI_API_KEY)
if USE_OPENAI:
    try:
        from openai import OpenAI
        oai = OpenAI(api_key=OPENAI_API_KEY)
        OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
    except Exception as e:
        print("[RAG] OpenAI not available:", e)
        USE_OPENAI = False
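# With no key configured, answer() below falls back to returning the top
# retrieved context instead of calling the LLM.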
# ---------- Artifacts you already have ----------
FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
META_PATH = os.getenv("META_PATH", "squad_v2_meta.pkl")

CACHE = {
    "index": None,
    "contexts": None,
    "encoder": None,
    "model_name": None,
}
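# CACHE is a module-level singleton: load_artifacts() fills it once, and every
# request reads from it, so the index and encoder are loaded a single time.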
def _coerce_text_list(x):
    """Accepts list[str] or list[dict]; extracts text nicely."""
    out = []
    if isinstance(x, list):
        for it in x:
            if isinstance(it, str):
                out.append(it)
            elif isinstance(it, dict):
                # common keys people use
                text = it.get("text") or it.get("content") or it.get("ctx") or ""
                if text:
                    out.append(text)
    return out
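# Illustrative inputs/outputs (hypothetical data): both shapes reduce to the
# same list of strings, and dict entries with no usable text are dropped.
#   _coerce_text_list(["Paris is the capital."])        -> ["Paris is the capital."]
#   _coerce_text_list([{"text": "Paris is the capital."},
#                      {"note": "no text key here"}])   -> ["Paris is the capital."]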
def load_artifacts():
    if CACHE["index"] is not None:
        return
    # 1) FAISS
    if not os.path.exists(FAISS_PATH):
        raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
    index = faiss.read_index(FAISS_PATH)
    # 2) META
    if not os.path.exists(META_PATH):
        raise FileNotFoundError(f"Missing meta file: {META_PATH}")
    with open(META_PATH, "rb") as f:
        meta = pickle.load(f)
    # parse meta
    model_name = "all-MiniLM-L6-v2"
    contexts = None
    if isinstance(meta, dict):
        # common keys
        model_name = meta.get("model") or meta.get("encoder") or model_name
        contexts = (
            meta.get("contexts")
            or meta.get("texts")
            or meta.get("documents")
            or meta.get("corpus")
        )
    else:
        # meta is just a list of contexts
        contexts = meta
    # normalize contexts
    contexts = _coerce_text_list(contexts) if contexts is not None else []
    if not contexts:
        raise ValueError("No contexts found in meta; expected a list of texts.")
    # Align lengths (safeguard)
    ntotal = index.ntotal
    if ntotal != len(contexts):
        m = min(ntotal, len(contexts))
        print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
        # We can't resize FAISS easily here; instead trim contexts so we never index out of range.
        contexts = contexts[:m]
    # 3) load encoder eagerly now, so the first click doesn't pay the model-load cost
    encoder = SentenceTransformer(model_name)
    CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
    print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
          f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")
def _retrieve(question: str, k: int):
    # encode query; FAISS expects float32
    q_emb = CACHE["encoder"].encode([question]).astype("float32")
    D, I = CACHE["index"].search(q_emb, int(k))
    idxs = I[0].tolist()
    dists = D[0].tolist()
    # guard against any out-of-range ids due to mismatched sizes
    max_ok = len(CACHE["contexts"]) - 1
    pairs = []
    for j, dist in zip(idxs, dists):
        if 0 <= j <= max_ok:
            pairs.append((j, dist, CACHE["contexts"][j]))
    return pairs
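# Note on scores: with a flat L2 index (assumed above), smaller distances mean
# closer matches. If the index were instead built on normalized embeddings with
# inner product, the query would need normalize_embeddings=True in encode() and
# larger scores would mean closer matches.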
def _build_prompt(question: str, pairs):
    chunks = []
    for i, (_, _d, ctx) in enumerate(pairs, start=1):
        # keep prompt size reasonable
        ctx_short = ctx.strip()
        if len(ctx_short) > 1200:
            ctx_short = ctx_short[:1200] + "..."
        chunks.append(f"[Source {i}] {ctx_short}")
    context_block = "\n\n".join(chunks) if chunks else "(no context)"
    prompt = f"""Answer strictly from the context below. If not answerable, say so.
Include [Source X] citations in your answer.
Context:
{context_block}
Question: {question}
Answer:"""
    return prompt
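# Shape of the prompt this produces (illustrative, with k=2):
#   Answer strictly from the context below. If not answerable, say so.
#   Include [Source X] citations in your answer.
#   Context:
#   [Source 1] <first passage, truncated to 1200 chars>
#
#   [Source 2] <second passage>
#   Question: <user question>
#   Answer: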
def answer(question: str, k: int):
    if not question.strip():
        return "Please enter a question.", [], None
    pairs = _retrieve(question, k)
    if not pairs:
        return "No results found in the index.", [], None
    # Build citations for the UI; gr.Dataframe expects rows (list of lists), not dicts
    citations = [{"rank": i + 1, "faiss_dist": round(d, 4),
                  "snippet": ctx[:240] + ("..." if len(ctx) > 240 else "")}
                 for i, (_idx, d, ctx) in enumerate(pairs)]
    rows = [[c["rank"], c["faiss_dist"], c["snippet"]] for c in citations]
    if USE_OPENAI:
        prompt = _build_prompt(question, pairs)
        try:
            resp = oai.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
            )
            ans = resp.choices[0].message.content
        except Exception as e:
            ans = f"LLM call failed: {e}\n\nTop results shown below."
    else:
        # Fallback: show top-1 context as the "answer"
        ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" +
               pairs[0][2][:1200])
    # simple JSON for debugging/export
    raw = {
        "k": int(k),
        "answer": ans,
        "citations": citations,
    }
    return ans, rows, json.dumps(raw, indent=2)
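# Quick smoke test from a Python shell (hypothetical question):
#   load_artifacts()
#   ans, rows, raw = answer("What is the capital of France?", 3)
#   print(ans)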
# ---------- UI ----------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)")
    with gr.Row():
        q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?")
        k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)")
    run_btn = gr.Button("Ask")
    ans_md = gr.Markdown(label="Answer")
    cites = gr.Dataframe(headers=["rank", "faiss_dist", "snippet"],
                         datatype=["number", "number", "str"],
                         row_count=(0, "dynamic"), label="Retrieved contexts")
    raw_json = gr.JSON(label="Debug / raw response")

    def _startup():
        load_artifacts()
        return "Ready."

    status = gr.Markdown()
    demo.load(_startup, inputs=None, outputs=status)
    run_btn.click(answer, [q, k], [ans_md, cites, raw_json])
if __name__ == "__main__":
    load_artifacts()
    demo.launch()