# Gradio demo: grounded Q&A over SQuAD v2 contexts using a FAISS index,
# sentence-transformer retrieval, and an optional OpenAI answer generator.
import os, pickle, json, gradio as gr
import numpy as np, faiss
from sentence_transformers import SentenceTransformer
# ---------- Optional LLM (OpenAI) ----------
# SECURITY: never hard-code API keys in source. Read the key from the
# environment instead (add OPENAI_API_KEY in Space -> Settings -> Secrets).
# The previously committed key must be revoked/rotated.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "")
USE_OPENAI = bool(OPENAI_API_KEY)
if USE_OPENAI:
    try:
        from openai import OpenAI
        oai = OpenAI(api_key=OPENAI_API_KEY)
        OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
    except Exception as e:
        # openai package missing or client init failed -> fall back to retrieval-only mode
        print("[RAG] OpenAI not available:", e)
        USE_OPENAI = False
# ---------- Artifacts you already have ----------
# Paths are overridable via environment variables for deployment flexibility.
FAISS_PATH = os.getenv("FAISS_PATH", "squad_v2.faiss")
META_PATH = os.getenv("META_PATH", "squad_v2_meta.pkl")

# Lazily-populated singletons shared by the retrieval helpers below.
CACHE = {"index": None, "contexts": None, "encoder": None, "model_name": None}
def _coerce_text_list(x):
"""Accepts list[str] or list[dict]; extracts text nicely."""
out = []
if isinstance(x, list):
for it in x:
if isinstance(it, str):
out.append(it)
elif isinstance(it, dict):
# common keys people use
text = it.get("text") or it.get("content") or it.get("ctx") or ""
if text:
out.append(text)
return out
def load_artifacts():
    """Load the FAISS index, context texts, and sentence encoder into CACHE.

    Idempotent: returns immediately if the cache is already populated.

    Raises:
        FileNotFoundError: if the FAISS index or meta file is missing.
        ValueError: if no context texts can be extracted from the meta file.
    """
    if CACHE["index"] is not None:
        return
    # 1) FAISS
    if not os.path.exists(FAISS_PATH):
        raise FileNotFoundError(f"Missing FAISS index: {FAISS_PATH}")
    index = faiss.read_index(FAISS_PATH)
    # 2) META
    if not os.path.exists(META_PATH):
        raise FileNotFoundError(f"Missing meta file: {META_PATH}")
    # NOTE(review): pickle.load executes arbitrary code on untrusted files —
    # only safe if this artifact always comes from a trusted build step.
    with open(META_PATH, "rb") as f:
        meta = pickle.load(f)
    # parse meta: a dict carries model name + contexts; anything else is
    # assumed to be a bare list of contexts
    model_name = "all-MiniLM-L6-v2"
    contexts = None
    if isinstance(meta, dict):
        # common keys
        model_name = meta.get("model") or meta.get("encoder") or model_name
        contexts = (
            meta.get("contexts")
            or meta.get("texts")
            or meta.get("documents")
            or meta.get("corpus")
        )
    else:
        # meta is just a list of contexts
        contexts = meta
    # normalize contexts to a flat list[str]
    contexts = _coerce_text_list(contexts) if contexts is not None else []
    if not contexts:
        raise ValueError("No contexts found in meta; expected a list of texts.")
    # Align lengths (safeguard)
    ntotal = index.ntotal
    if ntotal != len(contexts):
        m = min(ntotal, len(contexts))
        print(f"[RAG] WARNING: index.ntotal({ntotal}) != contexts({len(contexts)}). Trimming to {m}.")
        # We can't resize FAISS easily here; instead trim contexts so we never index out of range.
        contexts = contexts[:m]
    # 3) load encoder (lazy; we instantiate now to avoid first-click delay)
    encoder = SentenceTransformer(model_name)
    CACHE.update(index=index, contexts=contexts, encoder=encoder, model_name=model_name)
    print(f"[RAG] Loaded index={FAISS_PATH} (ntotal={CACHE['index'].ntotal}), "
          f"contexts={len(CACHE['contexts'])}, model={CACHE['model_name']}")
def _retrieve(question: str, k: int):
    """Embed *question* and return up to k (index, distance, context) triples."""
    # FAISS expects float32 query vectors
    query_vec = CACHE["encoder"].encode([question]).astype("float32")
    dists, ids = CACHE["index"].search(query_vec, int(k))
    contexts = CACHE["contexts"]
    hits = []
    for idx, dist in zip(ids[0].tolist(), dists[0].tolist()):
        # guard against out-of-range ids when index/context sizes mismatched
        if 0 <= idx < len(contexts):
            hits.append((idx, dist, contexts[idx]))
    return hits
def _build_prompt(question: str, pairs):
chunks = []
for i, (_, _d, ctx) in enumerate(pairs, start=1):
# keep prompt size reasonable
ctx_short = ctx.strip()
if len(ctx_short) > 1200:
ctx_short = ctx_short[:1200] + "..."
chunks.append(f"[Source {i}] {ctx_short}")
context_block = "\n\n".join(chunks) if chunks else "(no context)"
prompt = f"""Answer strictly from the context below. If not answerable, say so.
Include [Source X] citations in your answer.
Context:
{context_block}
Question: {question}
Answer:"""
return prompt
def answer(question: str, k: int):
    """Answer *question* from the top-k retrieved contexts.

    Returns a (answer_text, citations, raw_json) tuple; raw_json is None
    on the early-exit paths.
    """
    if not question.strip():
        return "Please enter a question.", [], None
    pairs = _retrieve(question, k)
    if not pairs:
        return "No results found in the index.", [], None
    # Build citations list for UI
    citations = []
    for rank, (_idx, dist, ctx) in enumerate(pairs, start=1):
        snippet = ctx[:240] + ("..." if len(ctx) > 240 else "")
        citations.append({"rank": rank, "faiss_dist": round(dist, 4), "snippet": snippet})
    if USE_OPENAI:
        prompt = _build_prompt(question, pairs)
        try:
            resp = oai.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
            )
            ans = resp.choices[0].message.content
        except Exception as e:
            # degrade gracefully: UI still shows the retrieved contexts
            ans = f"LLM call failed: {e}\n\nTop results shown below."
    else:
        # Fallback: show top-1 context as the "answer"
        ans = ("(No OPENAI_API_KEY set — showing most relevant context instead.)\n\n" +
               pairs[0][2][:1200])
    # simple JSON for debugging/export
    raw = {"k": int(k), "answer": ans, "citations": citations}
    return ans, citations, json.dumps(raw, indent=2)
# ---------- UI ----------
# Gradio Blocks layout: question box + top-k slider, answer markdown,
# a citations table, and a raw-JSON debug panel.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Nyxion Labs · Grounded Q&A (SQuAD v2, FAISS)")
    with gr.Row():
        q = gr.Textbox(label="Ask a question", placeholder="e.g., What is the capital of France?")
        k = gr.Slider(1, 10, value=3, step=1, label="Citations (top-k)")
    run_btn = gr.Button("Ask")
    ans_md = gr.Markdown(label="Answer")
    cites = gr.Dataframe(headers=["rank","faiss_dist","snippet"], datatype=["number","number","str"],
                         row_count=(0,"dynamic"), label="Retrieved contexts")
    raw_json = gr.JSON(label="Debug / raw response")

    def _startup():
        # Load artifacts on page load so the first question has no cold-start delay.
        load_artifacts()
        return "Ready."

    status = gr.Markdown()
    demo.load(_startup, inputs=None, outputs=status)
    run_btn.click(answer, [q, k], [ans_md, cites, raw_json])
if __name__ == "__main__":
    # Eager-load artifacts so missing-file errors surface before the server starts.
    load_artifacts()
    demo.launch()