# arXiv RAG Question Answering app
# Hybrid retrieval (BM25 + semantic search) with reranking and Groq LLM
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder, util
from rank_bm25 import BM25Okapi
from litellm import completion
import gradio as gr
import numpy as np
import re
SUBSET_SIZE = 150
GROQ_MODEL = "groq/llama-3.3-70b-versatile"
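# Stream the dataset so the full corpus is never downloaded.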
dataset = load_dataset(
    "nick007x/arxiv-papers",
    split="train",
    streaming=True
)
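# Materialize only the first SUBSET_SIZE records from the stream.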
subset = []
for i, record in enumerate(dataset):
    if i >= SUBSET_SIZE:
        break
    subset.append(record)
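# Keep only records with a substantive abstract (> 120 characters);
# shorter entries give the retriever too little signal.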
RAW_DOCUMENTS = []
for record in subset:
    abstract = record.get("abstract", "")
    if abstract and len(abstract) > 120:
        RAW_DOCUMENTS.append({
            "id": record.get("arxiv_id", "unknown"),
            "title": record.get("title", "No title"),
            "text": abstract,
            "authors": record.get("authors", []),
            "primary_subject": record.get("primary_subject", "")
        })
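# Chunking helpers: tokenize() lowercases and splits on word characters;
# create_chunks() slides a 200-token window with 40 tokens of overlap so
# text near a chunk boundary appears in both neighboring chunks.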
def tokenize(text: str):
    return re.findall(r"\w+", text.lower())
def create_chunks(text, chunk_size=200, overlap=40):
    tokens = tokenize(text)
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(tokens), step):
        piece = tokens[i:i + chunk_size]
        # Skip trailing fragments too short to carry useful context.
        if len(piece) < 60:
            continue
        chunks.append(" ".join(piece))
    return chunks
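# Flatten every document into chunk records that keep a back-reference
# to the source paper for citation.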
CHUNKS = []
for doc in RAW_DOCUMENTS:
    for ch in create_chunks(doc["text"]):
        CHUNKS.append({
            "doc_id": doc["id"],
            "title": doc["title"],
            "text": ch,
            "meta": doc
        })
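# Lexical index: BM25 over the tokenized chunks for keyword matching.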
bm25_corpus = [tokenize(chunk["text"]) for chunk in CHUNKS]
BM25 = BM25Okapi(bm25_corpus)
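# Dense index: encode every chunk once at startup with a bi-encoder so
# semantic queries reduce to a cosine-similarity lookup at query time.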
SEM_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
chunk_texts = [chunk["text"] for chunk in CHUNKS]
CHUNK_EMB = SEM_MODEL.encode(chunk_texts, convert_to_tensor=True)
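# Cross-encoder reranker: scores (query, chunk) pairs jointly, which is
# more precise but slower than the bi-encoder, so it only scores the
# small merged candidate pool.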
RERANK = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def retrieve_chunks(query, mode="hybrid", top_k=8):
    """Gather candidate chunks via BM25 and/or dense search, then rerank."""
    query_tokens = tokenize(query)
    candidate_indices = set()
    if mode in ("hybrid", "bm25"):
        scores = BM25.get_scores(query_tokens)
        top = np.argsort(scores)[::-1][:top_k]
        candidate_indices.update(top)
    if mode in ("hybrid", "semantic"):
        query_embedding = SEM_MODEL.encode(query, convert_to_tensor=True)
        cos_scores = util.cos_sim(query_embedding, CHUNK_EMB)[0].cpu().numpy()
        top = np.argsort(cos_scores)[::-1][:top_k]
        candidate_indices.update(top)
    candidate_indices = list(candidate_indices)
    if not candidate_indices:
        return []
    # Rerank the merged pool with the cross-encoder and keep the top 5.
    pairs = [(query, CHUNKS[i]["text"]) for i in candidate_indices]
    rerank_scores = RERANK.predict(pairs)
    sorted_idx = np.argsort(rerank_scores)[::-1][:5]
    final_indices = [candidate_indices[i] for i in sorted_idx]
    return [CHUNKS[i] for i in final_indices]
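# A quick sanity check one could run from a REPL (hypothetical query text):
#   retrieve_chunks("graph neural networks for molecules", mode="hybrid")
# returns at most 5 chunk dicts ordered by cross-encoder relevance.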
def format_context(chunks):
    """Render retrieved chunks as numbered, citable context snippets."""
    result = []
    for i, chunk in enumerate(chunks, start=1):
        result.append(f"[{i}] ({chunk['doc_id']}) {chunk['text'][:300]}...")
    return "\n\n".join(result)
def generate_answer(api_key, query, chunks):
    """Ask the Groq LLM the question with retrieved context; returns (answer, context)."""
    if not api_key or not api_key.strip():
        return "Please enter your Groq API key.", ""
    context = format_context(chunks)
    messages = [
        {
            "role": "system",
            "content": (
                "You are a retrieval-augmented assistant answering questions about arXiv papers. "
                "Use only the provided context and cite sources as [1], [2]. "
                "If the answer is not present, say that you do not know."
            )
        },
        {
            "role": "user",
            "content": f"Question:\n{query}\n\nContext:\n{context}"
        }
    ]
    try:
        response = completion(
            model=GROQ_MODEL,
            api_key=api_key,
            messages=messages,
            max_tokens=400,
            temperature=0.2
        )
        return response.choices[0].message.content, context
    except Exception as e:
        return f"LLM error: {e}", ""
def rag_pipeline(question, mode, api_key):
    """End-to-end pipeline: retrieve relevant chunks, then generate a grounded answer."""
    chunks = retrieve_chunks(question, mode=mode)
    return generate_answer(api_key, question, chunks)
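# Gradio UI: API key, question, and retrieval-mode inputs wired to the
# pipeline; the two Markdown panes show the answer and its sources.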
with gr.Blocks() as demo:
    gr.Markdown(
        "# 🔭 arXiv RAG Question Answering\n"
        "Hybrid BM25 and semantic retrieval over arXiv abstracts."
    )
    api_key = gr.Textbox(
        label="Groq API Key",
        type="password"
    )
    question = gr.Textbox(
        label="Question",
        lines=3
    )
    mode = gr.Radio(
        choices=["hybrid", "bm25", "semantic"],
        value="hybrid",
        label="Retrieval mode"
    )
    ask_button = gr.Button("Ask")
    answer_md = gr.Markdown()
    sources_md = gr.Markdown()
    ask_button.click(
        fn=rag_pipeline,
        inputs=[question, mode, api_key],
        outputs=[answer_md, sources_md]
    )
if __name__ == "__main__":
    demo.launch()