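# app.py — Research Paper RAG Assistant (Hugging Face Space).
# Pipeline: FAISS vector retrieval -> keyword scope filter -> cross-encoder
# reranking -> answer generation with Gemini, served through a Gradio UI.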
import os
import getpass
import torch
from dotenv import load_dotenv
import gradio as gr
import faiss
from typing import List, TypedDict
from sentence_transformers import CrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.prompts import ChatPromptTemplate
from tabulate import tabulate
from langchain.chat_models import init_chat_model
from langchain_huggingface import HuggingFaceEmbeddings

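# Load environment variables from a local .env file (e.g., GOOGLE_API_KEY)
# and resolve the on-disk location of the prebuilt FAISS index.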
load_dotenv()

PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_DIR = os.path.join(PROJECT_PATH, "faiss_index")
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL = "gemini-2.5-flash"
RERANKER_MODEL = "BAAI/bge-reranker-base"

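# The cross-encoder scores each (query, document) pair jointly, which is more
# accurate than the bi-encoder embedding similarity used for first-stage
# recall; run it on GPU when one is available.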
device = "cuda" if torch.cuda.is_available() else "cpu"
reranker_model = CrossEncoder(RERANKER_MODEL, device=device)

if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
llm = init_chat_model(LLM_MODEL, model_provider="google_genai")

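# Load the prebuilt index. allow_dangerous_deserialization=True is required
# because load_local unpickles the docstore; only enable it for index files
# you created yourself.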
| print("Loading FAISS index...") | |
| vector_store = FAISS.load_local( | |
| FAISS_INDEX_DIR, | |
| embeddings=embeddings, | |
| allow_dangerous_deserialization=True, | |
| ) | |
| print("FAISS index loaded successfully") | |
# Shared pipeline state (LangGraph-style); kept as documentation of the data
# flow, though the functions below pass question/context/answer directly.
class State(TypedDict):
    question: str
    context: List
    answer: str

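# Three-stage retrieval: (1) over-fetch faiss_broad_k candidates from FAISS,
# (2) drop candidates that mention no in-scope CS/AI keyword, (3) rerank the
# survivors with the cross-encoder and keep the top_k whose score clears
# semantic_threshold.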
def retrieve(question: str, faiss_broad_k=30, metadata_min_keep=8, top_k=5, semantic_threshold=0.35):
    # Stage 1: broad vector search to gather candidates for reranking.
    docs_with_scores = vector_store.similarity_search_with_score(question, k=faiss_broad_k)
    print(f"[retrieve] FAISS returned {len(docs_with_scores)} candidates")

    # Stage 2: keep only candidates whose title, concepts, or text mention an in-scope topic.
    SCOPE_B_KEYWORDS = [
        "machine learning", "artificial intelligence", "deep learning", "robotics",
        "data science", "neural network", "quantum computing", "automation",
        "computer vision", "nlp", "natural language", "algorithm", "software",
        "engineering", "big data", "reinforcement learning"
    ]
    meta_filtered = []
    for doc, _ in docs_with_scores:
        text = f"{doc.metadata.get('title', '')} {doc.metadata.get('concepts', '')} {doc.page_content}".lower()
        if any(k in text for k in SCOPE_B_KEYWORDS):
            meta_filtered.append(doc)

    # Fallback: if no candidate matched, give up on out-of-scope questions;
    # otherwise let the reranker judge the top FAISS hits directly.
    if not meta_filtered:
        qlow = question.lower()
        if not any(k in qlow for k in SCOPE_B_KEYWORDS):
            return []
        meta_filtered = [doc for doc, _ in docs_with_scores[:metadata_min_keep]]

    # Stage 3: cross-encoder reranking, then threshold and top-k cutoff.
    rerank_inputs = [[question, doc.page_content] for doc in meta_filtered]
    scores = reranker_model.predict(rerank_inputs)
    reranked = sorted(zip(meta_filtered, scores), key=lambda x: x[1], reverse=True)
    filtered = [(doc, s) for doc, s in reranked if s >= semantic_threshold]
    return [doc for doc, _ in filtered[:top_k]]

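# Turn the retrieved documents into a Markdown answer: a metadata table is
# rendered with tabulate and handed to Gemini together with the page contents,
# so the model can summarize and append the table per the system prompt.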
def generate_with_table(question: str, docs: List):
    if not docs:
        return f"I cannot find anything based on the search term **{question}**."

    # Build a Markdown (pipe) table of paper metadata for the prompt.
    rows = []
    for d in docs:
        m = d.metadata
        rows.append([
            m.get("title", ""),
            m.get("pub_year", ""),
            m.get("authors", ""),
            m.get("concepts", ""),
        ])
    headers = ["Title", "Year", "Authors", "Concepts"]
    papers_table = tabulate(rows, headers=headers, tablefmt="pipe")
    docs_content = "\n\n".join(doc.page_content for doc in docs)

    prompt_template = ChatPromptTemplate.from_messages([
        (
            "system",
            "You are an expert RAG system. Answer the user's question based ONLY on the provided context. "
            "After your answer, append a Markdown table of the retrieved papers. "
            "If there are results: say how many you found for [Search term]. "
            "If none: say 'I cannot find anything based on the search term [Search term]'. "
            "Output format: [Summary] \n\n [Markdown Table]. Context: {context}"
        ),
        ("human", "Question: {question}. Table of papers: \n\n{papers_table}"),
    ])
    messages = prompt_template.invoke({
        "question": question,
        "context": docs_content,
        "papers_table": papers_table,
    })
    response = llm.invoke(messages)
    return response.content

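# Gradio entry point: retrieve, then generate; an empty retrieval result is
# turned into an explicit "cannot find" message by generate_with_table.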
def rag_pipeline(question: str):
    if not question.strip():
        return "Please enter a question."
    docs = retrieve(question)
    return generate_with_table(question, docs)

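# Simple single-input/single-output UI; the Markdown output component renders
# both the generated summary and the appended paper table.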
demo = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Enter your research query:", placeholder="e.g., deep learning in robotics"),
    outputs=gr.Markdown(label="Response"),
    title="📚 Research Paper RAG Assistant",
    description="Retrieves and summarizes papers related to your query using FAISS, CrossEncoder, and Gemini.",
)

if __name__ == "__main__":
    demo.launch()