Spaces:
Sleeping
Sleeping
| import os, glob, pickle, gc | |
| from typing import List, Dict, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| from tqdm import tqdm | |
| from pypdf import PdfReader | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
# --------- CONFIG ----------
EMB_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2" # 384-dim, fast on CPU
LLM_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # small chat model
CHUNK_SIZE = 700     # words per chunk (see _chunk_text)
CHUNK_OVERLAP = 120  # words shared between consecutive chunks
TOP_K = 4            # number of chunks retrieved per question

# Cached globals (Space stays warm between requests)
_emb_model = None      # SentenceTransformer embedder, lazily created by _ensure_models
_llm_tokenizer = None  # tokenizer for the chat LLM
_llm_model = None      # causal LM used for answer generation
_faiss_index = None    # faiss.IndexFlatIP over L2-normalized embeddings (cosine)
_meta = None           # list of {"source", "chunk"} rows, row i matches index vector i
| # --------- HELPERS ---------- | |
def _load_pdf(path: str) -> str:
    """Extract the text of every page in a PDF, joined by newlines.

    Pages with no extractable text contribute an empty string; a file
    that cannot be read at all yields "" after logging a warning.
    """
    pages = []
    try:
        for page in PdfReader(path).pages:
            extracted = page.extract_text()
            pages.append(extracted if extracted else "")
    except Exception as e:
        print(f"[WARN] PDF read failed for {path}: {e}")
    return "\n".join(pages)
| def _load_txt(path: str) -> str: | |
| with open(path, "r", encoding="utf-8", errors="ignore") as f: | |
| return f.read() | |
| def _chunk_text(text: str, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) -> List[str]: | |
| words = text.split() | |
| chunks, i = [], 0 | |
| step = max(1, chunk_size - chunk_overlap) | |
| while i < len(words): | |
| chunks.append(" ".join(words[i:i+chunk_size])) | |
| i += step | |
| return chunks | |
def _ensure_models():
    """Lazily load the embedder and chat LLM into the module-level caches.

    Safe to call repeatedly; models are only downloaded/constructed once
    per process (the Space stays warm between requests).
    """
    global _emb_model, _llm_tokenizer, _llm_model
    if _emb_model is None:
        _emb_model = SentenceTransformer(EMB_MODEL_ID)
    if _llm_tokenizer is None or _llm_model is None:
        _llm_tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
        _llm_model = AutoModelForCausalLM.from_pretrained(
            LLM_MODEL_ID,
            torch_dtype=torch.float32,
            device_map="cpu",
        )
| def _reset_index(): | |
| global _faiss_index, _meta | |
| _faiss_index = None | |
| _meta = None | |
| gc.collect() | |
def _build_index_from_files(files: List[str]) -> Tuple[int, int]:
    """
    Build FAISS from uploaded files. Returns (#files, #chunks).

    Only .pdf, .txt and .md files are indexed; other extensions are
    skipped. If no usable text is found, any existing index is cleared
    and (0, 0) is returned.
    """
    global _faiss_index, _meta
    _ensure_models()
    # 1) Load raw text per file, skipping unsupported types and empty docs.
    docs = []
    for path in files:
        lower = path.lower()
        if lower.endswith(".pdf"):
            txt = _load_pdf(path)
        elif lower.endswith((".txt", ".md")):
            txt = _load_txt(path)
        else:
            continue
        if txt.strip():
            docs.append({"source": os.path.basename(path), "text": txt})
    # 2) Chunk each document, keeping the source filename for citations.
    dataset = [
        {"source": d["source"], "chunk": ch}
        for d in docs
        for ch in _chunk_text(d["text"])
    ]
    if not dataset:
        _reset_index()
        return (0, 0)
    # 3) Embed all chunks in one batched call — SentenceTransformer
    # batches lists natively, which is far faster than the previous
    # one-chunk-per-encode() Python loop.
    texts = [row["chunk"] for row in dataset]
    embs = _emb_model.encode(
        texts,
        batch_size=64,
        show_progress_bar=True,
        normalize_embeddings=True,
    ).astype("float32")
    # 4) Inner product on normalized vectors == cosine similarity.
    index = faiss.IndexFlatIP(embs.shape[1])
    index.add(embs)
    _faiss_index = index
    _meta = dataset
    return (len(docs), len(dataset))
def _retrieve(query: str, k=TOP_K) -> List[Dict]:
    """Return up to k chunks most similar to *query*.

    Each result is {"score": float cosine similarity, "source": filename,
    "text": chunk text}, ordered best-first.
    """
    q = _emb_model.encode(query, normalize_embeddings=True).astype("float32")
    D, I = _faiss_index.search(np.expand_dims(q, 0), k)
    results = []
    for score, idx in zip(D[0], I[0]):
        # Fix: faiss pads with index -1 when the index holds fewer than k
        # vectors; _meta[-1] would then silently return the LAST chunk as
        # a bogus extra hit. Skip the padding entries instead.
        if idx < 0:
            continue
        row = _meta[idx]
        results.append({"score": float(score), "source": row["source"], "text": row["chunk"]})
    return results
| def _build_prompt(question: str, ctxs: List[Dict]) -> str: | |
| context_block = "\n\n---\n".join( | |
| [f"[{i+1}] Source: {c['source']}\n{c['text']}" for i, c in enumerate(ctxs)] | |
| ) | |
| system_rules = ( | |
| "You are a careful assistant. Answer ONLY using the provided context. " | |
| "If the answer is not in the context, say you don't know." | |
| ) | |
| user_block = ( | |
| f"Question: {question}\n\n" | |
| f"Context (use strictly):\n{context_block}\n\n" | |
| "Answer:" | |
| ) | |
| return f"<|system|>\n{system_rules}\n<|user|>\n{user_block}\n<|assistant|>\n" | |
def _generate_answer(question: str, ctxs: List[Dict], max_new_tokens=220) -> str:
    """Generate a grounded answer for *question* with the cached LLM.

    Args:
        question: User question.
        ctxs: Retrieved context rows (see _retrieve).
        max_new_tokens: Generation budget.

    Returns:
        The model's answer text, stripped of surrounding whitespace.
    """
    inputs = _llm_tokenizer(_build_prompt(question, ctxs), return_tensors="pt")
    with torch.no_grad():
        out = _llm_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # Greedy decoding; passing temperature with do_sample=False is
            # ignored and triggers a transformers warning, so it is omitted.
            do_sample=False,
        )
    # Fix: decode only the newly generated tokens. Decoding the whole
    # sequence with skip_special_tokens=True can strip the "<|assistant|>"
    # marker (it is a special token), in which case splitting on it
    # returned the entire prompt along with the answer.
    prompt_len = inputs["input_ids"].shape[1]
    answer_ids = out[0][prompt_len:]
    return _llm_tokenizer.decode(answer_ids, skip_special_tokens=True).strip()
| # --------- GRADIO LOGIC ---------- | |
def init_with_samples():
    """
    Optional: build an index from bundled sample docs on startup.
    You can put .txt in a local /docs folder if you like.
    """
    sample_dir = "docs"
    if not os.path.isdir(sample_dir):
        return "No sample docs bundled. Upload your own to get started."
    files = [p for p in glob.glob(os.path.join(sample_dir, "*")) if os.path.isfile(p)]
    if not files:
        return "No sample docs bundled. Upload your own to get started."
    nfiles, nchunks = _build_index_from_files(files)
    return f"Initialized with {nfiles} sample files → {nchunks} chunks."
def upload_and_index(files):
    """Gradio callback: (re)build the vector index from uploaded files.

    Args:
        files: Gradio file payload — a list of tempfile-like wrappers with
            a ``.name`` attribute (older gradio) or plain path strings
            (``type="filepath"`` / newer gradio). Both are accepted.

    Returns:
        Markdown status string describing the build result.
    """
    if not files:
        _reset_index()
        return "No files uploaded. Index cleared."
    # Fix: newer gradio versions hand back plain str paths, on which the
    # previous unconditional `f.name` raised AttributeError.
    paths = [f if isinstance(f, str) else f.name for f in files]
    nfiles, nchunks = _build_index_from_files(paths)
    return f"Indexed {nfiles} files → {nchunks} chunks. (Embedding dim=384)"
def ask_question(history, question):
    """Gradio callback: answer *question* over the index and extend chat history."""
    if _faiss_index is None or _meta is None:
        msg = "No index yet. Upload documents first or add sample docs."
        return history + [[question, msg]]
    ctxs = _retrieve(question, k=TOP_K)
    answer = _generate_answer(question, ctxs)
    # add simple citations footer
    citation_tags = (f"[{i}:{c['source']}]" for i, c in enumerate(ctxs, start=1))
    footer = " ".join(citation_tags)
    return history + [[question, f"{answer}\n\nSources: {footer}"]]
def clear_index():
    """Gradio callback: wipe the index and report the result."""
    _reset_index()
    return "Index cleared."
# UI layout and event wiring. `demo` is the Blocks app launched below.
with gr.Blocks(title="RAG: Chat with Your Docs (CPU)") as demo:
    gr.Markdown("# 🔎 Retrieval-Augmented Generation (CPU)\nUpload PDFs or text notes, then ask questions.")
    # Build the startup status eagerly so the page reports index state on load.
    status = gr.Markdown(init_with_samples())

    with gr.Row():
        # Left column: document management.
        with gr.Column(scale=1):
            file_u = gr.File(
                label="Upload PDFs or .txt",
                file_count="multiple",
                file_types=[".pdf", ".txt", ".md"],
            )
            build_btn = gr.Button("Build / Rebuild Index")
            clear_btn = gr.Button("Clear Index")
            build_out = gr.Markdown()
            clear_out = gr.Markdown()
        # Right column: chat interface.
        with gr.Column(scale=2):
            chat = gr.Chatbot(height=420)
            q = gr.Textbox(label="Ask a question")
            ask_btn = gr.Button("Ask")

    build_btn.click(upload_and_index, inputs=file_u, outputs=build_out)
    # clear_index takes no inputs, so it can be passed directly (no lambda needed).
    clear_btn.click(clear_index, outputs=clear_out)
    ask_btn.click(ask_question, inputs=[chat, q], outputs=chat)
    q.submit(ask_question, inputs=[chat, q], outputs=chat)

if __name__ == "__main__":
    demo.launch()