# nyxion-rag / app.py
import os, pickle, json, gradio as gr
import numpy as np, faiss
from sentence_transformers import SentenceTransformer
# ---------- Optional LLM (OpenAI) ----------
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")  # add in Space -> Settings -> Secrets
USE_OPENAI = bool(OPENAI_API_KEY)
if USE_OPENAI:
try:
from openai import OpenAI
oai = OpenAI(api_key=OPENAI_API_KEY)
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
except Exception as e:
print("[RAG] OpenAI not available:", e)
USE_OPENAI = False
# ---------- Artifacts you already have ----------
FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
META_PATH = os.getenv("META_PATH", "squad_v2_meta.pkl")
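# Module-level cache so the index, contexts, and encoder load once per process
# and are reused across Gradio callbacks.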
CACHE = {
"index": None,
"contexts": None,
"encoder": None,
"model_name": None
}
def _coerce_text_list(x):
    """Accept list[str] or list[dict]; extract the text from common dict keys."""
out = []
if isinstance(x, list):
for it in x:
if isinstance(it, str):
out.append(it)
elif isinstance(it, dict):
# common keys people use
text = it.get("text") or it.get("content") or it.get("ctx") or ""
if text:
out.append(text)
return out
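# A quick illustration of the coercion above (hypothetical inputs):
#   _coerce_text_list(["a", {"text": "b"}, {"content": "c"}, {"id": 1}])
#   -> ["a", "b", "c"]   (dicts without a recognized text key are dropped)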
def load_artifacts():
if CACHE["index"] is not None:
return
# 1) FAISS
if not os.path.exists(FAISS_PATH):
raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
index = faiss.read_index(FAISS_PATH)
# 2) META
if not os.path.exists(META_PATH):
raise FileNotFoundError(f"Missing meta file: {META_PATH}")
with open(META_PATH, "rb") as f:
meta = pickle.load(f)
# parse meta
model_name = "all-MiniLM-L6-v2"
contexts = None
if isinstance(meta, dict):
# common keys
model_name = meta.get("model") or meta.get("encoder") or model_name
contexts = (
meta.get("contexts")
or meta.get("texts")
or meta.get("documents")
or meta.get("corpus")
)
else:
# meta is just a list of contexts
contexts = meta
# normalize contexts
contexts = _coerce_text_list(contexts) if contexts is not None else []
if not contexts:
raise ValueError("No contexts found in meta; expected a list of texts.")
# Align lengths (safeguard)
ntotal = index.ntotal
if ntotal != len(contexts):
m = min(ntotal, len(contexts))
print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
# We can’t resize FAISS easily here; instead trim contexts so we never index out of range.
contexts = contexts[:m]
    # 3) load encoder now (rather than lazily) to avoid a first-click delay
encoder = SentenceTransformer(model_name)
CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")
def _retrieve(question: str, k: int):
# encode query; FAISS expects float32
q_emb = CACHE["encoder"].encode([question]).astype("float32")
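    # Assumption: the index may have been built for cosine/inner-product search.
    # If so, L2-normalize the query to match; for a plain L2 index this branch
    # never fires and behavior is unchanged.
    if getattr(CACHE["index"], "metric_type", None) == faiss.METRIC_INNER_PRODUCT:
        faiss.normalize_L2(q_emb)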
D, I = CACHE["index"].search(q_emb, int(k))
idxs = I[0].tolist()
dists = D[0].tolist()
    # guard against out-of-range ids from a size mismatch and against the -1
    # padding FAISS returns when fewer than k neighbors exist
    max_ok = len(CACHE["contexts"]) - 1
pairs = []
for j, dist in zip(idxs, dists):
if 0 <= j <= max_ok:
pairs.append((j, dist, CACHE["contexts"][j]))
return pairs
def _build_prompt(question: str, pairs):
chunks = []
for i, (_, _d, ctx) in enumerate(pairs, start=1):
# keep prompt size reasonable
ctx_short = ctx.strip()
if len(ctx_short) > 1200:
ctx_short = ctx_short[:1200] + "..."
chunks.append(f"[Source {i}] {ctx_short}")
context_block = "\n\n".join(chunks) if chunks else "(no context)"
prompt = f"""Answer strictly from the context below. If not answerable, say so.
Include [Source X] citations in your answer.
Context:
{context_block}
Question: {question}
Answer:"""
return prompt
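# Schematically, the prompt above renders as:
#   Answer strictly from the context below. ...
#   [Source 1] <first retrieved chunk>
#   [Source 2] <second retrieved chunk>
#   Question: <user question>
#   Answer: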
def answer(question: str, k: int):
if not question.strip():
return "Please enter a question.", [], None
pairs = _retrieve(question, k)
if not pairs:
return "No results found in the index.", [], None
    # Build citations twice: dicts for the debug JSON, plain rows for the table
    citations = [{"rank": i + 1, "faiss_dist": round(d, 4),
                  "snippet": ctx[:240] + ("..." if len(ctx) > 240 else "")}
                 for i, (_idx, d, ctx) in enumerate(pairs)]
    rows = [[c["rank"], c["faiss_dist"], c["snippet"]] for c in citations]
if USE_OPENAI:
prompt = _build_prompt(question, pairs)
try:
resp = oai.chat.completions.create(
model=OPENAI_MODEL,
messages=[{"role":"user","content":prompt}],
temperature=0.2
)
ans = resp.choices[0].message.content
except Exception as e:
ans = f"LLM call failed: {e}\n\nTop results shown below."
else:
# Fallback: show top-1 context as the “answer”
ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" +
pairs[0][2][:1200])
    # simple JSON for debugging/export
    raw = {
        "k": int(k),
        "answer": ans,
        "citations": citations,
    }
    # gr.Dataframe expects rows (a list of lists), not a list of dicts
    return ans, rows, json.dumps(raw, indent=2)
# ---------- UI ----------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)")
with gr.Row():
q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?")
k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)")
run_btn = gr.Button("Ask")
ans_md = gr.Markdown(label="Answer")
cites = gr.Dataframe(headers=["rank","faiss_dist","snippet"], datatype=["number","number","str"],
row_count=(0,"dynamic"), label="Retrieved contexts")
raw_json = gr.JSON(label="Debug / raw response")
    def _startup():
        try:
            load_artifacts()
            return "Ready."
        except Exception as e:
            return f"Startup failed: {e}"
status = gr.Markdown()
demo.load(_startup, inputs=None, outputs=status)
run_btn.click(answer, [q, k], [ans_md, cites, raw_json])
if __name__ == "__main__":
load_artifacts()
demo.launch()
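# Local smoke test (a sketch; assumes the squad_v2.faiss / squad_v2_meta.pkl
# artifacts sit next to this file):
#   python -c "from app import load_artifacts, answer; \
#              load_artifacts(); print(answer('When did the Normans invade England?', 3)[0])"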