# Hugging Face Space app: embedding-based retrieval demo over a JSONL corpus.
import os
import json
import gradio as gr
import numpy as np
import faiss
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

# ---- CONFIG ----
# Sentence-embedding model used for both documents and queries.
EMBED_MODEL_ID = os.getenv("EMBED_MODEL_ID", "guipa01/best-telos-model")
# Corpus file: one JSON object per line with "id", "type", and "text" fields
# (see load_docs()).
JSONL_PATH = os.getenv("JSONL_PATH", "docs.jsonl")
# Put cache somewhere that survives restarts if you enabled Persistent Storage.
# If you did NOT enable it, use ".cache" (still helps within the same container life).
CACHE_DIR = os.getenv("CACHE_DIR", "./.cache")  # change to ".cache" if no /data
os.makedirs(CACHE_DIR, exist_ok=True)
# On-disk cache artifacts written by ensure_ready():
INDEX_PATH = os.path.join(CACHE_DIR, "docs.faiss")  # serialized FAISS index
META_PATH = os.path.join(CACHE_DIR, "docs_meta.npz")  # ids / types / docs arrays
TYPES_PATH = os.path.join(CACHE_DIR, "type_choices.json")  # dropdown choices

# Lazily initialized globals — all populated by ensure_ready().
embedder = None  # SentenceTransformer instance
index = None  # faiss.IndexFlatIP over the document embeddings
ids = None  # document ids, parallel to docs/types
types = None  # document type labels, parallel to ids/docs
docs = None  # document texts, parallel to ids/types
type_choices = None  # ["(any)"] + sorted unique type labels
def load_docs():
    """Load the JSONL corpus and return its (ids, types, texts) columns."""
    dataset = load_dataset("json", data_files=JSONL_PATH, split="train")
    return dataset["id"], dataset["type"], dataset["text"]
def ensure_ready():
    """
    Lazily initializes model + index.
    Loads cached FAISS index if available; otherwise builds and caches.

    Populates the module globals: embedder, index, ids, types, docs,
    type_choices. Safe to call repeatedly — returns immediately once
    everything is initialized.
    """
    global embedder, index, ids, types, docs, type_choices
    # Already fully initialized: nothing to do.
    if index is not None and embedder is not None:
        return
    if embedder is None:
        embedder = SentenceTransformer(EMBED_MODEL_ID)
    # Try fast path: load cached index + metadata
    # NOTE(review): there is no cache invalidation — if docs.jsonl changes,
    # the stale cache is still served; delete the files in CACHE_DIR to rebuild.
    if os.path.exists(INDEX_PATH) and os.path.exists(META_PATH) and os.path.exists(TYPES_PATH):
        index = faiss.read_index(INDEX_PATH)
        # allow_pickle=True is needed because the saved arrays hold strings
        # (object dtype), which np.load refuses by default.
        meta = np.load(META_PATH, allow_pickle=True)
        ids = meta["ids"]
        types = meta["types"]
        docs = meta["docs"]
        with open(TYPES_PATH, "r", encoding="utf-8") as f:
            type_choices = json.load(f)
        return
    # Slow path: build everything once
    ids, types, docs = load_docs()
    # Normalized embeddings + inner-product index == cosine similarity.
    doc_emb = embedder.encode(
        docs,
        batch_size=32,
        normalize_embeddings=True,
        show_progress_bar=False
    ).astype("float32")
    dim = doc_emb.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(doc_emb)
    type_choices = ["(any)"] + sorted(set(types))
    # Cache to disk
    faiss.write_index(index, INDEX_PATH)
    np.savez_compressed(META_PATH, ids=np.array(ids), types=np.array(types), docs=np.array(docs))
    with open(TYPES_PATH, "w", encoding="utf-8") as f:
        json.dump(type_choices, f)
def retrieve(query: str, top_k: int = 5, type_filter: str = "(any)"):
    """Embed *query*, search the FAISS index, and return results as Markdown.

    Parameters
    ----------
    query : str
        Free-text question; blank/whitespace-only input short-circuits.
    top_k : int
        Maximum number of results to return (Gradio sliders may pass floats,
        so the value is coerced to int).
    type_filter : str
        Keep only documents of this type, or "(any)" for no filtering.

    Returns a Markdown string (results, or a help message when empty).
    """
    if not query.strip():
        return "Type a question above."
    ensure_ready()
    top_k = int(top_k)  # Gradio can deliver the slider value as a float
    q_emb = embedder.encode([query], normalize_embeddings=True).astype("float32")
    # Bug fix: the original always searched a fixed pool of min(50, len(docs))
    # candidates, so with an active type filter any matching docs ranked past
    # the first 50 were silently missed even when fewer than top_k results
    # were returned. Search the whole corpus when filtering; without a filter
    # a pool of max(50, top_k) suffices since no candidates are discarded.
    n_docs = len(docs)
    pool = n_docs if type_filter != "(any)" else min(max(50, top_k), n_docs)
    scores, idxs = index.search(q_emb, k=pool)
    results = []
    for score, i in zip(scores[0], idxs[0]):
        if i < 0:  # FAISS pads with -1 when fewer than k vectors exist
            continue
        if type_filter != "(any)" and types[i] != type_filter:
            continue
        results.append((float(score), str(ids[i]), str(types[i]), str(docs[i])))
        if len(results) >= top_k:
            break
    if not results:
        return "No matches (try removing the type filter or increasing Top K)."
    out = []
    for rank, (score, doc_id, doc_type, doc_text) in enumerate(results, 1):
        out.append(
            f"### {rank}. score={score:.4f} | id={doc_id} | type={doc_type}\n{doc_text}\n"
        )
    return "\n".join(out)
def get_type_choices():
    """Return the dropdown choices, falling back to ["(any)"] on failure.

    Deliberate best-effort: initialization can fail at import time (missing
    JSONL, cold cache, model download error) and the UI should still come up
    rather than crash, so any exception yields the neutral default.
    """
    try:
        ensure_ready()
    except Exception:
        return ["(any)"]
    return type_choices
# Gradio UI: question box + top-k slider + type dropdown -> Markdown results.
_question = gr.Textbox(label="Question", placeholder="Ask something...")
_top_k = gr.Slider(1, 20, value=5, step=1, label="Top K")
_type_filter = gr.Dropdown(choices=get_type_choices(), value="(any)", label="Filter by type")
_results = gr.Markdown(label="Retrieved documents")

demo = gr.Interface(
    fn=retrieve,
    inputs=[_question, _top_k, _type_filter],
    outputs=_results,
    title="Embedding Retrieval Demo",
    description="Query → embed → cosine similarity search over your JSONL docs.",
)

if __name__ == "__main__":
    demo.launch()