"""Research-paper RAG assistant.

Retrieves candidate papers from a local FAISS index, keyword-filters them to
an AI/CS scope, reranks the survivors with a CrossEncoder, and asks Gemini to
summarise the final set — served through a Gradio UI.
"""

import getpass
import os
from typing import List, TypedDict

import faiss
import gradio as gr
import torch
from dotenv import load_dotenv
from langchain.chat_models import init_chat_model
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_huggingface import HuggingFaceEmbeddings
from sentence_transformers import CrossEncoder
from tabulate import tabulate

load_dotenv()

PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_DIR = os.path.join(PROJECT_PATH, "faiss_index")
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
LLM_MODEL = 'gemini-2.5-flash'
RERANKER_MODEL = "BAAI/bge-reranker-base"

# Keywords that define the in-scope (AI/CS) domain for the keyword filter.
# Hoisted to module level so the list is not rebuilt on every retrieve() call.
SCOPE_B_KEYWORDS = [
    "machine learning", "artificial intelligence", "deep learning",
    "robotics", "data science", "neural network", "quantum computing",
    "automation", "computer vision", "nlp", "natural language",
    "algorithm", "software", "engineering", "big data",
    "reinforcement learning"
]

device = "cuda" if torch.cuda.is_available() else "cpu"
reranker_model = CrossEncoder(RERANKER_MODEL, device=device)

# Prompt interactively only when the key is not already in the environment.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
llm = init_chat_model(LLM_MODEL, model_provider="google_genai")

print("Loading FAISS index...")
vector_store = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings=embeddings,
    # Index is produced locally by this project, so deserialization is trusted.
    allow_dangerous_deserialization=True,
)
print("FAISS index loaded successfully")


class State(TypedDict):
    """LangGraph-style pipeline state.

    NOTE(review): not referenced by the functions below — presumably kept for
    a graph-based pipeline variant; confirm before removing.
    """
    question: str
    context: List
    answer: str


def retrieve(
    question: str,
    faiss_broad_k: int = 30,
    metadata_min_keep: int = 8,
    top_k: int = 5,
    semantic_threshold: float = 0.35,
) -> List:
    """Retrieve, scope-filter, and rerank documents for *question*.

    Pipeline:
      1. Broad FAISS similarity search (``faiss_broad_k`` candidates).
      2. Keep candidates whose title/concepts/content mention a scope keyword.
      3. If nothing matched but the *question* itself is in scope, fall back
         to the top ``metadata_min_keep`` FAISS hits.
      4. CrossEncoder rerank; keep scores >= ``semantic_threshold``.

    Args:
        question: User query string.
        faiss_broad_k: Candidate pool size for the initial FAISS search.
        metadata_min_keep: Fallback pool size when the keyword filter is empty.
        top_k: Maximum number of documents returned.
        semantic_threshold: Minimum CrossEncoder relevance score to keep.

    Returns:
        Up to ``top_k`` documents, best-first; ``[]`` when out of scope or
        nothing relevant was found.
    """
    docs_with_scores = vector_store.similarity_search_with_score(question, k=faiss_broad_k)
    print(f"[retrieve] FAISS returned {len(docs_with_scores)} candidates")

    # Step 2: cheap keyword scope filter over title + concepts + content.
    meta_filtered = []
    for doc, _ in docs_with_scores:
        text = f"{doc.metadata.get('title','')} {doc.metadata.get('concepts','')} {doc.page_content}".lower()
        if any(k in text for k in SCOPE_B_KEYWORDS):
            meta_filtered.append(doc)

    # Step 3: fallback — only if the question itself looks in-scope.
    if not meta_filtered:
        qlow = question.lower()
        if not any(k in qlow for k in SCOPE_B_KEYWORDS):
            return []
        meta_filtered = [doc for doc, _ in docs_with_scores[:metadata_min_keep]]

    # Guard: FAISS may have returned zero candidates, leaving nothing to
    # rerank — avoid calling CrossEncoder.predict([]).
    if not meta_filtered:
        return []

    # Step 4: CrossEncoder rerank and threshold.
    rerank_inputs = [[question, doc.page_content] for doc in meta_filtered]
    scores = reranker_model.predict(rerank_inputs)
    reranked = sorted(zip(meta_filtered, scores), key=lambda x: x[1], reverse=True)
    filtered = [(doc, s) for doc, s in reranked if s >= semantic_threshold]
    return [doc for doc, _ in filtered[:top_k]]


def generate_with_table(question: str, docs: List) -> str:
    """Ask the LLM to answer *question* from *docs* and append a paper table.

    Args:
        question: The user's query.
        docs: Retrieved documents (LangChain ``Document``-like objects with
            ``metadata`` and ``page_content``).

    Returns:
        Markdown answer text, or a fixed "not found" message when *docs*
        is empty (skips the LLM call entirely).
    """
    if not docs:
        return f"I cannot find anything based on the search term **{question}**."

    # Build a Markdown (pipe) table of paper metadata for the prompt.
    rows = []
    for d in docs:
        m = d.metadata
        rows.append([
            m.get("title", ""),
            m.get("pub_year", ""),
            m.get("authors", ""),
            m.get("concepts", "")
        ])
    headers = ["Title", "Year", "Authors", "Concepts"]
    papers_table = tabulate(rows, headers=headers, tablefmt="pipe")

    docs_content = "\n\n".join(doc.page_content for doc in docs)

    prompt_template = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are an expert RAG system. Answer the user's question based ONLY on the provided context. "
            "After your answer, append a Markdown table of the retrieved papers. "
            "If there are results: say how many you found for [Search term]. "
            "If none: say 'I cannot find anything based on the search term [Search term]'. "
            "Output format: [Summary] \n\n [Markdown Table]. Context: {context}"
        ),
        ("human", "Question: {question}. Table of papers: \n\n{papers_table}"),
    ])

    messages = prompt_template.invoke({
        "question": question,
        "context": docs_content,
        "papers_table": papers_table
    })
    response = llm.invoke(messages)
    return response.content


def rag_pipeline(question: str) -> str:
    """Gradio entry point: validate input, retrieve, and generate an answer."""
    if not question.strip():
        return "Please enter a question."
    docs = retrieve(question)
    response = generate_with_table(question, docs)
    return response


demo = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Enter your research query:", placeholder="e.g., deep learning in robotics"),
    outputs=gr.Markdown(label="Response"),
    title="📚 Research Paper RAG Assistant",
    description="Retrieves and summarizes papers related to your query using FAISS, CrossEncoder, and Gemini."
)

if __name__ == "__main__":
    demo.launch()