import os
import getpass
import torch
from dotenv import load_dotenv
import gradio as gr
import faiss
from typing import List, TypedDict
from sentence_transformers import CrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.prompts import ChatPromptTemplate
from tabulate import tabulate
from langchain.chat_models import init_chat_model
from langchain_huggingface import HuggingFaceEmbeddings
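
# --- Configuration ---
# Load environment variables from a local .env file (if present), resolve the
# on-disk location of the FAISS index, and pin the embedding, LLM, and reranker models.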
load_dotenv()
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_DIR = os.path.join(PROJECT_PATH, "faiss_index")
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
LLM_MODEL = 'gemini-2.5-flash'
RERANKER_MODEL = "BAAI/bge-reranker-base"
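
# Run the cross-encoder reranker on the GPU when one is available.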
device = "cuda" if torch.cuda.is_available() else "cpu"
reranker_model = CrossEncoder(RERANKER_MODEL, device=device)
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
llm = init_chat_model(LLM_MODEL, model_provider="google_genai")
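
# Load the persisted FAISS index from disk. The index is assumed to have been
# built beforehand (e.g., by a separate ingestion script that populated faiss_index/).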
print("Loading FAISS index...")
vector_store = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
print("FAISS index loaded successfully")
class State(TypedDict):
    question: str
    context: List
    answer: str
def retrieve(question, faiss_broad_k=30, metadata_min_keep=8, top_k=5, semantic_threshold=0.35):
    # Stage 1: broad FAISS retrieval.
    docs_with_scores = vector_store.similarity_search_with_score(question, k=faiss_broad_k)
    print(f"[retrieve] FAISS returned {len(docs_with_scores)} candidates")

    SCOPE_B_KEYWORDS = [
        "machine learning", "artificial intelligence", "deep learning", "robotics",
        "data science", "neural network", "quantum computing", "automation",
        "computer vision", "nlp", "natural language", "algorithm", "software",
        "engineering", "big data", "reinforcement learning"
    ]

    # Stage 2: keep only candidates whose title, concepts, or content mention a scope keyword.
    meta_filtered = []
    for doc, _ in docs_with_scores:
        text = f"{doc.metadata.get('title','')} {doc.metadata.get('concepts','')} {doc.page_content}".lower()
        if any(k in text for k in SCOPE_B_KEYWORDS):
            meta_filtered.append(doc)

    # If nothing matched, fall back to the top FAISS hits -- but only when the
    # question itself is on-topic; otherwise treat the query as out of scope.
    if len(meta_filtered) == 0:
        qlow = question.lower()
        if not any(k in qlow for k in SCOPE_B_KEYWORDS):
            return []
        meta_filtered = [doc for doc, _ in docs_with_scores[:metadata_min_keep]]

    # Stage 3: rerank with the cross-encoder, drop low-scoring pairs, keep the top_k.
    rerank_inputs = [[question, doc.page_content] for doc in meta_filtered]
    scores = reranker_model.predict(rerank_inputs)
    reranked = sorted(zip(meta_filtered, scores), key=lambda x: x[1], reverse=True)
    filtered = [(doc, s) for doc, s in reranked if s >= semantic_threshold]
    final_docs = [doc for doc, _ in filtered[:top_k]]
    return final_docs
def generate_with_table(question: str, docs: List):
    if not docs:
        return f"I cannot find anything based on the search term **{question}**."

    # Build a Markdown (pipe) table from the retrieved papers' metadata.
    rows = []
    for d in docs:
        m = d.metadata
        rows.append([
            m.get("title", ""),
            m.get("pub_year", ""),
            m.get("authors", ""),
            m.get("concepts", "")
        ])
    headers = ["Title", "Year", "Authors", "Concepts"]
    papers_table = tabulate(rows, headers=headers, tablefmt="pipe")

    docs_content = "\n\n".join(doc.page_content for doc in docs)

    prompt_template = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are an expert RAG system. Answer the user's question based ONLY on the provided context. "
            "After your answer, append a Markdown table of the retrieved papers. "
            "If there are results: say how many you found for [Search term]. "
            "If none: say 'I cannot find anything based on the search term [Search term]'. "
            "Output format: [Summary] \n\n [Markdown Table]. Context: {context}"
        ),
        ("human", "Question: {question}. Table of papers: \n\n{papers_table}"),
    ])
    messages = prompt_template.invoke({
        "question": question,
        "context": docs_content,
        "papers_table": papers_table
    })
    response = llm.invoke(messages)
    return response.content
def rag_pipeline(question: str):
    if not question.strip():
        return "Please enter a question."
    docs = retrieve(question)
    response = generate_with_table(question, docs)
    return response
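
# Gradio UI: a single textbox in, a Markdown-rendered answer (summary + paper table) out.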
demo = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Enter your research query:", placeholder="e.g., deep learning in robotics"),
    outputs=gr.Markdown(label="Response"),
    title="📚 Research Paper RAG Assistant",
    description="Retrieves and summarizes papers related to your query using FAISS, CrossEncoder, and Gemini."
)
if __name__ == "__main__":
    demo.launch()