Spaces:
Sleeping
Sleeping
| import os, gradio as gr | |
| import numpy as np | |
| import faiss | |
| from datasets import load_dataset | |
| from sentence_transformers import SentenceTransformer | |
# Optional LLM step: the app still works without it by showing the top
# retrieved context instead. The key is read from the environment (on HF
# Spaces: Settings -> Repository secrets) and must never be hard-coded here.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "").strip()
OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
USE_OPENAI = bool(OPENAI_API_KEY)
oai = None  # OpenAI client; stays None unless initialization below succeeds.
print(f"[RAG] OPENAI_API_KEY found: {bool(OPENAI_API_KEY)}")
if USE_OPENAI:
    try:
        from openai import OpenAI

        oai = OpenAI(api_key=OPENAI_API_KEY)
        print(f"[RAG] OpenAI initialized with model: {OPENAI_MODEL}")
    except Exception as e:
        # Best-effort: any import/client failure downgrades to retrieval-only mode.
        print("[RAG] OpenAI import failed:", e)
        USE_OPENAI = False
else:
    print("[RAG] No OpenAI API key detected. Set OPENAI_API_KEY in Space Settings.")
# Tunables, each overridable through an environment variable
# (on HF Spaces: Settings -> Variables).
MODEL_NAME = os.environ.get("EMBED_MODEL", "all-MiniLM-L6-v2")  # SentenceTransformer model id
# Spliced into the dataset split "train[:<SQUAD_SLICE>]", so a row count
# such as "1000" or a percentage such as "2%" both work.
SQUAD_SLICE = os.environ.get("SQUAD_SLICE", "2000")
# Character cap applied to each context before it is shown or sent to the LLM.
MAX_CTX_CHAR = int(os.environ.get("MAX_CTX_CHAR", "1200"))
# Retrieval state, populated by build_index().
STATE = dict(index=None, contexts=None, encoder=None)
| def _fallback_corpus(): | |
| # Tiny backup if SQuAD fails to download (offline) | |
| return [ | |
| "Paris is the capital and most populous city of France.", | |
| "The Pacific Ocean is the largest and deepest of Earth's oceanic divisions.", | |
| "The human heart has four chambers: two atria and two ventricles.", | |
| "Mount Everest is Earth's highest mountain above sea level.", | |
| "Photosynthesis converts light energy into chemical energy in plants.", | |
| "The Nile is a major north-flowing river in northeastern Africa.", | |
| "Berlin is the capital and largest city of Germany.", | |
| "Tokyo is the capital of Japan and one of the world's most populous cities.", | |
| "The Great Wall of China is one of the most famous landmarks in the world.", | |
| "DNA contains the genetic instructions for all living organisms.", | |
| ] | |
def _load_contexts():
    """Return the unique SQuAD v2 contexts in ``train[:SQUAD_SLICE]``, or the fallback corpus.

    Any failure while loading the dataset (e.g. an offline Space), or an
    empty slice, is logged and answered with ``_fallback_corpus()`` so the
    app can still start.
    """
    try:
        print(f"[RAG] Loading SQuAD v2 sample: train[:{SQUAD_SLICE}] …")
        ds = load_dataset("rajpurkar/squad_v2", split=f"train[:{SQUAD_SLICE}]")
        # The same context can appear on many rows; dict.fromkeys removes
        # duplicates while keeping first-seen order.
        contexts = list(dict.fromkeys(row["context"] for row in ds))
        if not contexts:
            raise RuntimeError("Empty SQuAD slice.")
    except Exception as e:
        print("[RAG] Could not load SQuAD, using fallback corpus:", e)
        return _fallback_corpus()
    return contexts


def build_index():
    """Build a small FAISS index directly from SQuAD v2 (no uploaded files).

    Returns early if ``STATE["index"]`` is already set. Otherwise loads the
    contexts (falling back to a tiny built-in corpus), embeds them with
    ``MODEL_NAME`` and stores ``index``, ``contexts`` and ``encoder`` in
    ``STATE``.

    Returns:
        A short human-readable status message.
    """
    if STATE["index"] is not None:
        return "Index already loaded."
    contexts = _load_contexts()
    print(f"[RAG] Encoding {len(contexts)} contexts with {MODEL_NAME} …")
    encoder = SentenceTransformer(MODEL_NAME)
    # Batch-encode and cast to float32, the dtype FAISS works with.
    emb = encoder.encode(contexts, show_progress_bar=True, batch_size=128).astype("float32")
    # Exact brute-force L2 search: simple, robust and needs no training step.
    index = faiss.IndexFlatL2(emb.shape[1])
    index.add(emb)
    STATE.update(index=index, contexts=contexts, encoder=encoder)
    print(f"[RAG] Ready: ntotal={index.ntotal}")
    return f"Built FAISS index with {len(contexts)} contexts."
def retrieve(question: str, k: int):
    """Return up to ``k`` nearest stored contexts for ``question``.

    Requires ``build_index()`` to have populated ``STATE``. ``k`` is coerced
    to int and clamped to the index size. A non-positive ``k`` returns ``[]``
    rather than being passed to FAISS.

    Returns:
        A list of dicts, closest first, each with:
        ``rank`` (1-based), ``faiss_dist`` (the distance reported by the L2
        index; smaller is closer), ``snippet`` (first 240 chars, "…" when cut)
        and ``full`` (the complete context).
    """
    index, contexts = STATE["index"], STATE["contexts"]
    k = min(int(k), index.ntotal)
    if k < 1:
        return []
    q = STATE["encoder"].encode([question]).astype("float32")
    distances, ids = index.search(q, k)
    pairs = []
    for rank, (i, dist) in enumerate(zip(ids[0], distances[0]), start=1):
        # Guard against padding ids (e.g. -1) that FAISS may return.
        if not 0 <= i < len(contexts):
            continue
        ctx = contexts[i]
        pairs.append({
            "rank": rank,
            "faiss_dist": float(dist),
            "snippet": ctx[:240] + ("…" if len(ctx) > 240 else ""),
            "full": ctx,
        })
    return pairs
def build_prompt(question: str, pairs, max_chars=None):
    """Assemble the grounded-QA prompt sent to the LLM.

    Each retrieved context becomes a numbered ``[Source N]`` block (N is the
    1-based position in ``pairs``) so the model can cite it.

    Args:
        question: The user's question, inserted verbatim.
        pairs: Retrieved items as produced by ``retrieve``; only the
            ``"full"`` key is read.
        max_chars: Per-source character cap. Longer contexts are truncated
            and suffixed with "…". Defaults to ``MAX_CTX_CHAR``.

    Returns:
        The prompt string. With no pairs the context section reads
        "(no context)".
    """
    limit = MAX_CTX_CHAR if max_chars is None else max_chars
    blocks = []
    for j, p in enumerate(pairs, start=1):
        ctx = p["full"].strip()
        if len(ctx) > limit:
            ctx = ctx[:limit] + "…"
        blocks.append(f"[Source {j}] {ctx}")
    context_block = "\n\n".join(blocks) if blocks else "(no context)"
    return f"""Answer strictly from the context below. If not answerable, say so.
Include [Source X] citations in your answer.
Context:
{context_block}
Question: {question}
Answer:"""
def answer(question: str, k: int):
    """Gradio handler: answer ``question`` from the top-``k`` retrieved contexts.

    Builds the FAISS index on first use. When OpenAI is enabled, the retrieved
    contexts go to the LLM via ``build_prompt``. Otherwise, or if the LLM call
    raises, the most relevant context (capped at ``MAX_CTX_CHAR``) is shown.

    Args:
        question: Free-text question; ``None`` or blank input yields an
            "idle" reply.
        k: Number of contexts to retrieve (coerced to int by ``retrieve``).

    Returns:
        A tuple ``(answer_markdown, cite_rows, status)``:
        - ``cite_rows`` is a list of ``[rank, faiss_dist, snippet]`` rows,
          in the column order of the "Retrieved Contexts" Dataframe.
        - ``status`` is a dict for the JSON status panel.
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question.", [], {"status": "idle", "openai_enabled": USE_OPENAI}
    if STATE["index"] is None:
        build_index()
    pairs = retrieve(question, k)
    if not pairs:
        return "No results in index.", [], {"status": "empty", "openai_enabled": USE_OPENAI}
    # Plain list-of-lists rows, a format gr.Dataframe accepts directly, with
    # values ordered like its headers ["rank", "faiss_dist", "snippet"].
    cites = [[p["rank"], round(p["faiss_dist"], 4), p["snippet"]] for p in pairs]
    top_ctx = pairs[0]["full"][:MAX_CTX_CHAR]
    if USE_OPENAI:
        prompt = build_prompt(question, pairs)
        try:
            print(f"[RAG] Calling OpenAI with model: {OPENAI_MODEL}")
            resp = oai.chat.completions.create(
                model=OPENAI_MODEL,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.2,
                max_tokens=500,
            )
            # message.content is optional in the SDK; never hand None to the Markdown output.
            ans = resp.choices[0].message.content or "(The model returned an empty answer.)"
            print("[RAG] OpenAI response received successfully")
        except Exception as e:
            # Best-effort: degrade to the top retrieved context instead of failing the request.
            print(f"[RAG] LLM call failed: {e}")
            ans = f"❌ LLM call failed: {e}\n\n**Top result shown below:**\n\n{top_ctx}"
    else:
        ans = (
            "⚠️ **No OPENAI_API_KEY set** — Add it in Space Settings → Repository secrets\n\n"
            "**Showing most relevant context instead:**\n\n" + top_ctx
        )
    return ans, cites, {
        "status": "ok",
        "ntotal": STATE["index"].ntotal,
        "model": MODEL_NAME,
        "openai_enabled": USE_OPENAI,
        "openai_model": OPENAI_MODEL if USE_OPENAI else None,
    }
# ------------------- UI -------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
## Nyxion Labs · Grounded Q&A (RAG Demo)
Ask questions and get answers grounded in context with citations.
""")
    if not USE_OPENAI:
        # Blank lines keep the heading and the closing sentence out of the
        # numbered list (Markdown would otherwise fold them into it).
        gr.Markdown("""
⚠️ **OpenAI API Key Not Detected**

To enable AI-generated answers:
1. Go to Space Settings
2. Add `OPENAI_API_KEY` as a repository secret
3. Restart the Space

Currently showing raw context retrieval only.
""")
    with gr.Row():
        q = gr.Textbox(
            label="Ask a question",
            placeholder="e.g., What is the capital of Germany?",
            lines=2,
        )
        k = gr.Slider(1, 10, value=3, step=1, label="Number of Citations (top-k)")
    btn = gr.Button("🔍 Ask", variant="primary")
    ans = gr.Markdown(label="Answer")
    cites = gr.Dataframe(
        headers=["rank", "faiss_dist", "snippet"],
        datatype=["number", "number", "str"],
        row_count=(0, "dynamic"),
        label="Retrieved Contexts",
    )
    meta = gr.JSON(label="System Status")

    def _startup():
        """Build the index on page load and return a status dict for the JSON panel."""
        try:
            msg = build_index()
        except Exception as e:
            # Report the failure in the UI instead of breaking the load event.
            # An index failure does not change whether OpenAI is usable.
            return {"status": f"Startup build failed: {e}", "openai_enabled": USE_OPENAI}
        return {
            "status": msg,
            "openai_enabled": USE_OPENAI,
            "openai_model": OPENAI_MODEL if USE_OPENAI else None,
            "embed_model": MODEL_NAME,
        }

    demo.load(_startup, inputs=None, outputs=meta)
    btn.click(answer, [q, k], [ans, cites, meta])
    q.submit(answer, [q, k], [ans, cites, meta])  # Allow Enter key to submit
def _main():
    """Script entry point: build the retrieval index eagerly, then serve the UI."""
    build_index()  # later build_index() calls return early once STATE holds an index
    demo.launch()


if __name__ == "__main__":
    _main()