import os, pickle, json, gradio as gr
import numpy as np, faiss
from sentence_transformers import SentenceTransformer

# ---------- Optional LLM (OpenAI) ----------
# Never hardcode the key; add it in Space -> Settings -> Secrets.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
USE_OPENAI = bool(OPENAI_API_KEY)
if USE_OPENAI:
    try:
        from openai import OpenAI
        oai = OpenAI(api_key=OPENAI_API_KEY)
        OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
    except Exception as e:
        print("[RAG] OpenAI not available:", e)
        USE_OPENAI = False

# ---------- Artifacts you already have ----------
FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
META_PATH = os.getenv("META_PATH", "squad_v2_meta.pkl")

CACHE = {"index": None, "contexts": None, "encoder": None, "model_name": None}


def _coerce_text_list(x):
    """Accepts list[str] or list[dict]; extracts text nicely."""
    out = []
    if isinstance(x, list):
        for it in x:
            if isinstance(it, str):
                out.append(it)
            elif isinstance(it, dict):
                # common keys people use
                text = it.get("text") or it.get("content") or it.get("ctx") or ""
                if text:
                    out.append(text)
    return out


def load_artifacts():
    if CACHE["index"] is not None:
        return
    # 1) FAISS
    if not os.path.exists(FAISS_PATH):
        raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
    index = faiss.read_index(FAISS_PATH)
    # 2) META
    if not os.path.exists(META_PATH):
        raise FileNotFoundError(f"Missing meta file: {META_PATH}")
    with open(META_PATH, "rb") as f:
        meta = pickle.load(f)
    # parse meta
    model_name = "all-MiniLM-L6-v2"
    contexts = None
    if isinstance(meta, dict):
        # common keys
        model_name = meta.get("model") or meta.get("encoder") or model_name
        contexts = (meta.get("contexts") or meta.get("texts")
                    or meta.get("documents") or meta.get("corpus"))
    else:
        # meta is just a list of contexts
        contexts = meta
    # normalize contexts
    contexts = _coerce_text_list(contexts) if contexts is not None else []
    if not contexts:
        raise ValueError("No contexts found in meta; expected a list of texts.")
    # Align lengths (safeguard)
    ntotal = index.ntotal
    if ntotal != len(contexts):
        m = min(ntotal, len(contexts))
        print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
        # We can't resize FAISS easily here; instead trim contexts so we never index out of range.
        contexts = contexts[:m]
    # 3) load encoder (instantiate now to avoid first-click delay)
    encoder = SentenceTransformer(model_name)
    CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
    print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
          f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")


def _retrieve(question: str, k: int):
    # encode query; FAISS expects float32
    q_emb = CACHE["encoder"].encode([question]).astype("float32")
    D, I = CACHE["index"].search(q_emb, int(k))
    idxs = I[0].tolist()
    dists = D[0].tolist()
    # guard against any out-of-range hits due to mismatched sizes
    max_ok = len(CACHE["contexts"]) - 1
    pairs = []
    for j, dist in zip(idxs, dists):
        if 0 <= j <= max_ok:
            pairs.append((j, dist, CACHE["contexts"][j]))
    return pairs
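
# NOTE: D from index.search() is in whatever metric the index was built with.
# For the common IndexFlatL2, a smaller "faiss_dist" means a closer match; for
# an inner-product index over normalized embeddings, larger scores are better.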
chunks.append(f"[Source {i}] {ctx_short}") context_block = "\n\n".join(chunks) if chunks else "(no context)" prompt = f"""Answer strictly from the context below. If not answerable, say so. Include [Source X] citations in your answer. Context: {context_block} Question: {question} Answer:""" return prompt def answer(question: str, k: int): if not question.strip(): return "Please enter a question.", [], None pairs = _retrieve(question, k) if not pairs: return "No results found in the index.", [], None # Build citations list for UI citations = [{"rank": i+1, "faiss_dist": round(d, 4), "snippet": ctx[:240] + ("..." if len(ctx) > 240 else "")} for i, (_idx, d, ctx) in enumerate(pairs)] if USE_OPENAI: prompt = _build_prompt(question, pairs) try: resp = oai.chat.completions.create( model=OPENAI_MODEL, messages=[{"role":"user","content":prompt}], temperature=0.2 ) ans = resp.choices[0].message.content except Exception as e: ans = f"LLM call failed: {e}\n\nTop results shown below." else: # Fallback: show top-1 context as the “answer” ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" + pairs[0][2][:1200]) # simple JSON for debugging/export raw = { "k": int(k), "answer": ans, "citations": citations } return ans, citations, json.dumps(raw, indent=2) # ---------- UI ---------- with gr.Blocks(theme=gr.themes.Soft()) as demo: gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)") with gr.Row(): q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?") k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)") run_btn = gr.Button("Ask") ans_md = gr.Markdown(label="Answer") cites = gr.Dataframe(headers=["rank","faiss_dist","snippet"], datatype=["number","number","str"], row_count=(0,"dynamic"), label="Retrieved contexts") raw_json = gr.JSON(label="Debug / raw response") def _startup(): load_artifacts() return "Ready." status = gr.Markdown() demo.load(_startup, inputs=None, outputs=status) run_btn.click(answer, [q, k], [ans_md, cites, raw_json]) if __name__ == "__main__": load_artifacts() demo.launch()