Hugging Face Space metadata — status: Sleeping
# arXiv RAG Question Answering app
# Hybrid retrieval (BM25 + semantic search) with reranking and Groq LLM
import re
from itertools import islice

import gradio as gr
import numpy as np
from datasets import load_dataset
from litellm import completion
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, CrossEncoder, util
# Size of the streamed sample and the Groq chat model used for generation.
SUBSET_SIZE = 150
GROQ_MODEL = "groq/llama-3.3-70b-versatile"

# Stream the dataset so the full corpus is never downloaded; islice lazily
# consumes only the first SUBSET_SIZE records from the stream.
dataset = load_dataset(
    "nick007x/arxiv-papers",
    split="train",
    streaming=True
)
subset = list(islice(dataset, SUBSET_SIZE))

# Keep only records with a substantive abstract (>120 chars) and flatten
# them into the document schema consumed by the chunker below.
RAW_DOCUMENTS = [
    {
        "id": record.get("arxiv_id", "unknown"),
        "title": record.get("title", "No title"),
        "text": abstract,
        "authors": record.get("authors", []),
        "primary_subject": record.get("primary_subject", ""),
    }
    for record in subset
    if (abstract := record.get("abstract", "")) and len(abstract) > 120
]
def tokenize(text: str) -> list[str]:
    """Lowercase *text* and return its word tokens (maximal ``\\w+`` runs)."""
    lowered = text.lower()
    return re.findall(r"\w+", lowered)
def create_chunks(text, chunk_size=200, overlap=40, min_tokens=60):
    """Split *text* into overlapping word-token chunks.

    Args:
        text: Raw text to split (tokenized via :func:`tokenize`).
        chunk_size: Maximum number of tokens per chunk.
        overlap: Tokens shared between consecutive chunks, to keep context
            continuous across chunk boundaries.
        min_tokens: Chunks shorter than this are dropped — tiny tail
            fragments carry little retrieval signal.

    Returns:
        List of chunk strings (tokens re-joined with single spaces).

    Raises:
        ValueError: If ``overlap >= chunk_size`` — the window would never
            advance (step of zero or negative).
    """
    if overlap >= chunk_size:
        raise ValueError("overlap must be smaller than chunk_size")
    tokens = tokenize(text)
    step = chunk_size - overlap
    chunks = []
    for start in range(0, len(tokens), step):
        piece = tokens[start:start + chunk_size]
        if len(piece) >= min_tokens:
            chunks.append(" ".join(piece))
    return chunks
# Flatten every document into retrievable chunks, each remembering its
# parent document for citation display.
CHUNKS = [
    {"doc_id": doc["id"], "title": doc["title"], "text": piece, "meta": doc}
    for doc in RAW_DOCUMENTS
    for piece in create_chunks(doc["text"])
]

# Lexical index: BM25 over the tokenized chunk texts.
BM25 = BM25Okapi([tokenize(chunk["text"]) for chunk in CHUNKS])

# Dense index: sentence embeddings pre-computed once for every chunk.
SEM_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
CHUNK_EMB = SEM_MODEL.encode([chunk["text"] for chunk in CHUNKS], convert_to_tensor=True)

# Cross-encoder used to rerank the merged candidate pool.
RERANK = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def retrieve_chunks(query, mode="hybrid", top_k=8, final_k=5):
    """Retrieve the chunks most relevant to *query*.

    Candidates come from BM25 (lexical), semantic cosine similarity, or the
    union of both ("hybrid"); the pooled candidates are then reranked with a
    cross-encoder and the best ``final_k`` are returned.

    Args:
        query: Natural-language question.
        mode: One of "hybrid", "bm25", "semantic". Any other value yields
            no candidates and an empty result.
        top_k: Candidates drawn from each retriever before reranking.
        final_k: Number of chunks returned after reranking.

    Returns:
        List of chunk dicts (possibly empty), best-first.
    """
    query_tokens = tokenize(query)
    candidate_indices = set()
    if mode in ("hybrid", "bm25"):
        scores = BM25.get_scores(query_tokens)
        candidate_indices.update(np.argsort(scores)[::-1][:top_k])
    if mode in ("hybrid", "semantic"):
        query_embedding = SEM_MODEL.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, CHUNK_EMB)[0].cpu().numpy()
        candidate_indices.update(np.argsort(cos_scores)[::-1][:top_k])
    # Sort for a deterministic candidate order, and bail out when an unknown
    # mode (or empty corpus) produced no candidates — the cross-encoder
    # cannot score an empty batch.
    candidate_indices = sorted(candidate_indices)
    if not candidate_indices:
        return []
    pairs = [(query, CHUNKS[i]["text"]) for i in candidate_indices]
    rerank_scores = RERANK.predict(pairs)
    best = np.argsort(rerank_scores)[::-1][:final_k]
    return [CHUNKS[candidate_indices[i]] for i in best]
def format_context(chunks):
    """Render retrieved chunks as a numbered, citable context string.

    Each chunk becomes one ``[n] (doc_id) text...`` entry with the text
    truncated to 300 characters; entries are separated by blank lines.
    """
    entries = [
        f"[{n}] ({chunk['doc_id']}) {chunk['text'][:300]}..."
        for n, chunk in enumerate(chunks, start=1)
    ]
    return "\n\n".join(entries)
def generate_answer(api_key, query, chunks):
    """Ask the Groq LLM to answer *query* grounded in the retrieved *chunks*.

    Returns an ``(answer, context)`` tuple. When the key is missing or the
    LLM call fails, the answer holds the error text and the context is empty.
    """
    if not api_key or not api_key.strip():
        return "Please enter your Groq API key.", ""

    context = format_context(chunks)
    system_prompt = (
        "You are a retrieval-augmented assistant answering questions about arXiv papers. "
        "Use only the provided context and cite sources as [1], [2]. "
        "If the answer is not present, say that you do not know."
    )
    user_prompt = f"Question:\n{query}\n\nContext:\n{context}"
    try:
        # Low temperature keeps answers grounded in the supplied context.
        response = completion(
            model=GROQ_MODEL,
            api_key=api_key,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            max_tokens=400,
            temperature=0.2,
        )
        return response.choices[0].message["content"], context
    except Exception as e:
        return f"LLM error: {e}", ""
def rag_pipeline(question, mode, api_key):
    """End-to-end Gradio handler: retrieve chunks, then generate a cited answer."""
    retrieved = retrieve_chunks(question, mode=mode)
    return generate_answer(api_key, question, retrieved)
# --- Gradio UI ---------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🔭 arXiv RAG Question Answering\n"
        "Hybrid BM25 and semantic retrieval over arXiv abstracts."
    )

    # Inputs: the user's Groq key (masked), the question, and retrieval mode.
    key_box = gr.Textbox(label="Groq API Key", type="password")
    question_box = gr.Textbox(label="Question", lines=3)
    mode_radio = gr.Radio(
        choices=["hybrid", "bm25", "semantic"],
        value="hybrid",
        label="Retrieval mode",
    )
    ask_button = gr.Button("Ask")

    # Outputs: the generated answer and the retrieved source snippets.
    answer_md = gr.Markdown()
    sources_md = gr.Markdown()

    ask_button.click(
        fn=rag_pipeline,
        inputs=[question_box, mode_radio, key_box],
        outputs=[answer_md, sources_md],
    )

if __name__ == "__main__":
    demo.launch()