Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import gradio as gr | |
| import chromadb | |
| import PyPDF2 | |
| from groq import Groq | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from rank_bm25 import BM25Okapi | |
# ── Models ───────────────────────────────────────────────────────────────────
# Groq chat-completion client (API key read from the GROQ_API_KEY env var),
# a bi-encoder for dense chunk/query embeddings, and a cross-encoder used to
# rerank retrieved chunks against the query.
groq_client = Groq(api_key=os.getenv("GROQ_API_KEY"))
embedder = SentenceTransformer("all-MiniLM-L6-v2")
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# ── ChromaDB (in-memory) ─────────────────────────────────────────────────────
# Ephemeral client: the vector index lives only for this process's lifetime.
chroma_client = chromadb.Client()
collection = chroma_client.get_or_create_collection("rag_docs")
# ── Global store for BM25 ────────────────────────────────────────────────────
# Parallel plain-text copy of every ingested chunk; the BM25 index is rebuilt
# from this list on each sparse search.
doc_store = []  # list of {"id": str, "text": str, "source": str}
# ── Helpers ──────────────────────────────────────────────────────────────────
def chunk_text(text, chunk_size=400, overlap=80):
    """Split *text* into overlapping word-based chunks.

    Args:
        text: Source text to split (whitespace-tokenized).
        chunk_size: Maximum number of words per chunk.
        overlap: Number of words shared between consecutive chunks.

    Returns:
        List of chunk strings; empty list for empty/whitespace-only text.

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the original while-loop
            stride ``chunk_size - overlap`` would be <= 0 and loop forever.
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    step = chunk_size - overlap  # stride between chunk starts
    return [
        " ".join(words[start : start + chunk_size])
        for start in range(0, len(words), step)
    ]
def extract_text_from_pdf(file_path):
    """Return the concatenated extracted text of every page in a PDF.

    Pages whose extraction yields ``None`` contribute an empty string.
    """
    with open(file_path, "rb") as handle:
        reader = PyPDF2.PdfReader(handle)
        return "".join(page.extract_text() or "" for page in reader.pages)
def extract_text_from_txt(file_path):
    """Return the full contents of a UTF-8 text file, ignoring undecodable bytes."""
    with open(file_path, encoding="utf-8", errors="ignore") as handle:
        return handle.read()
# ── Ingest ───────────────────────────────────────────────────────────────────
def ingest_documents(files):
    """Ingest uploaded files into ChromaDB and the BM25 ``doc_store``.

    Any previously ingested data is discarded first, so re-ingesting replaces
    rather than appends. Each file is read (PDF via PyPDF2, anything else as
    UTF-8 text), chunked, embedded, and stored in both indexes.

    Args:
        files: List of Gradio file objects (each exposes a ``.name`` path),
            or None/empty when nothing was uploaded.

    Returns:
        Human-readable status string for the UI.
    """
    global doc_store, collection
    if not files:
        return "⚠️ No files uploaded."
    # Reset both stores to a clean slate.
    doc_store = []
    chroma_client.delete_collection("rag_docs")
    collection = chroma_client.get_or_create_collection("rag_docs")
    total_chunks = 0
    file_names = []
    for file in files:
        path = file.name
        name = os.path.basename(path)
        file_names.append(name)
        # Case-insensitive extension check: ".PDF" uploads are PDFs too
        # (the original case-sensitive check parsed them as raw text).
        if path.lower().endswith(".pdf"):
            raw_text = extract_text_from_pdf(path)
        else:
            raw_text = extract_text_from_txt(path)
        if not raw_text.strip():
            continue
        chunks = chunk_text(raw_text)
        if not chunks:
            continue
        ids = [f"{name}_chunk_{i}" for i in range(len(chunks))]
        # Batch-encode and batch-insert: one embedder/Chroma call per file
        # instead of one per chunk.
        embeddings = embedder.encode(chunks).tolist()
        collection.add(
            ids=ids,
            embeddings=embeddings,
            documents=chunks,
            metadatas=[{"source": name, "chunk_index": i} for i in range(len(chunks))],
        )
        for chunk_id, chunk in zip(ids, chunks):
            doc_store.append({"id": chunk_id, "text": chunk, "source": name})
        total_chunks += len(chunks)
    return f"✅ Ingested {total_chunks} chunks from: {', '.join(file_names)}"
# ── Retrieval ────────────────────────────────────────────────────────────────
def vector_search(query, top_k=10):
    """Dense retrieval: embed *query* and fetch nearest chunks from ChromaDB.

    Args:
        query: Natural-language query string.
        top_k: Maximum number of chunks to return.

    Returns:
        List of ``{"text", "source"}`` dicts, best match first; empty list
        when nothing has been ingested.
    """
    if not doc_store:
        # Guard the empty-index case: min(top_k, 0) would ask Chroma for
        # n_results=0, which is rejected/meaningless.
        return []
    query_embedding = embedder.encode(query).tolist()
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=min(top_k, len(doc_store)),
    )
    return [
        {"text": doc, "source": meta["source"]}
        for doc, meta in zip(results["documents"][0], results["metadatas"][0])
    ]
def bm25_search(query, top_k=10):
    """Sparse retrieval: rank every stored chunk against *query* with BM25.

    Rebuilds the BM25 index from ``doc_store`` on each call; returns up to
    *top_k* ``{"text", "source"}`` dicts, best score first.
    """
    if not doc_store:
        return []
    corpus_tokens = [entry["text"].lower().split() for entry in doc_store]
    index = BM25Okapi(corpus_tokens)
    scores = index.get_scores(query.lower().split())
    order = sorted(range(len(scores)), key=lambda idx: scores[idx], reverse=True)
    return [
        {"text": doc_store[idx]["text"], "source": doc_store[idx]["source"]}
        for idx in order[:top_k]
    ]
def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
    """Merge two ranked chunk lists with Reciprocal Rank Fusion.

    Each chunk contributes ``1 / (k + rank + 1)`` per list it appears in.
    Chunks are deduplicated by the first 100 characters of their text, and
    the merged list is returned best fused score first.
    """
    fused_score = {}
    chunk_by_key = {}
    # Score both ranked lists with the same RRF formula.
    for ranked_list in (vector_results, bm25_results):
        for rank, chunk in enumerate(ranked_list):
            key = chunk["text"][:100]
            fused_score[key] = fused_score.get(key, 0) + 1 / (k + rank + 1)
            chunk_by_key[key] = chunk
    ordered_keys = sorted(fused_score, key=fused_score.get, reverse=True)
    return [chunk_by_key[key] for key in ordered_keys]
def rerank(query, chunks, top_k=5):
    """Re-order *chunks* by cross-encoder relevance to *query*, keeping *top_k*."""
    if not chunks:
        return []
    scores = reranker.predict([(query, chunk["text"]) for chunk in chunks])
    best_first = sorted(zip(scores, chunks), key=lambda pair: pair[0], reverse=True)
    return [chunk for _, chunk in best_first[:top_k]]
# ── Generate answer ──────────────────────────────────────────────────────────
def generate_answer(query, chunks):
    """Ask the Groq LLM to answer *query* using only the retrieved *chunks*.

    Each chunk is labelled ``[Source N: filename]`` in the prompt so the model
    can cite its sources; returns the model's text response.
    """
    # Build the numbered context block fed to the model.
    context_parts = []
    for i, chunk in enumerate(chunks):
        context_parts.append(f"[Source {i+1}: {chunk['source']}]\n{chunk['text']}")
    context = "\n\n".join(context_parts)
    # Grounding prompt: restricts the model to the context and demands
    # [Source N] citations so answers stay attributable.
    prompt = f"""You are a helpful research assistant. Answer the question based ONLY on the provided context.
For every claim you make, cite the source using [Source N] notation.
If the answer is not in the context, say "I couldn't find relevant information in the uploaded documents."
Question: {query}
Context:
{context}
Answer (with citations):"""
    response = groq_client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=1024
    )
    return response.choices[0].message.content
# ── Main pipeline ────────────────────────────────────────────────────────────
def ask_question(query):
    """Full RAG pipeline: hybrid retrieval → RRF fusion → rerank → LLM answer.

    Returns an ``(answer, sources)`` pair of strings for the Gradio outputs.
    """
    if not query.strip():
        return "⚠️ Please enter a question.", ""
    if not doc_store:
        return "⚠️ Please upload and ingest documents first.", ""
    # Hybrid retrieval: dense and sparse candidate lists.
    dense_hits = vector_search(query, top_k=10)
    sparse_hits = bm25_search(query, top_k=10)
    # Fuse both rankings, then let the cross-encoder pick the best few.
    candidates = reciprocal_rank_fusion(dense_hits, sparse_hits)
    top_chunks = rerank(query, candidates, top_k=5)
    answer = generate_answer(query, top_chunks)
    # Unique source filenames in first-seen order.
    unique_sources = dict.fromkeys(chunk["source"] for chunk in top_chunks)
    sources = "\n".join(f"• {name}" for name in unique_sources)
    return answer, f"📚 Sources referenced:\n{sources}"
# ── Gradio UI ────────────────────────────────────────────────────────────────
# Two-column layout: upload + ingest on the left, question + answer on the
# right. Button callbacks wire straight into ingest_documents/ask_question.
with gr.Blocks(title="Domain RAG System — Ask My Doc") as demo:
    gr.Markdown("# 📚 Ask My Doc — Domain-Specific RAG System")
    gr.Markdown(
        "Upload your documents (PDF or TXT), then ask questions. "
        "Uses **hybrid retrieval** (vector search + BM25) and **cross-encoder reranking** "
        "to give you cited, accurate answers."
    )
    with gr.Row():
        # Left column: document upload and ingestion status.
        with gr.Column(scale=1):
            gr.Markdown("### Step 1: Upload Documents")
            file_upload = gr.File(
                label="Upload PDF or TXT files",
                file_types=[".pdf", ".txt"],
                file_count="multiple"
            )
            ingest_btn = gr.Button("📥 Ingest Documents", variant="primary")
            ingest_status = gr.Textbox(label="Status", interactive=False)
        # Right column: question input and answer/sources outputs.
        with gr.Column(scale=2):
            gr.Markdown("### Step 2: Ask a Question")
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g. What are the key findings in this document?",
                lines=2
            )
            ask_btn = gr.Button("🔍 Get Answer", variant="primary")
            answer_output = gr.Textbox(label="Answer", lines=10, interactive=False)
            sources_output = gr.Textbox(label="Sources", lines=3, interactive=False)
    # Event wiring: ingest fills the status box; ask fills answer + sources.
    ingest_btn.click(fn=ingest_documents, inputs=[file_upload], outputs=[ingest_status])
    ask_btn.click(fn=ask_question, inputs=[question_input], outputs=[answer_output, sources_output])
demo.launch()