File size: 4,920 Bytes
5aa0be0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import getpass
import torch
from dotenv import load_dotenv
import gradio as gr
import faiss
from typing import List, TypedDict
from sentence_transformers import CrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.prompts import ChatPromptTemplate
from tabulate import tabulate
from langchain.chat_models import init_chat_model
from langchain_huggingface import HuggingFaceEmbeddings

load_dotenv()

# Resolve all paths relative to this file so the script works from any CWD.
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_DIR = os.path.join(PROJECT_PATH, "faiss_index")
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)
# Embedding model must match the one used to BUILD the FAISS index, or
# loaded vectors will be meaningless — TODO confirm against the index builder.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
LLM_MODEL = 'gemini-2.5-flash'
RERANKER_MODEL = "BAAI/bge-reranker-base"
# Prefer GPU for the cross-encoder when available; falls back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
reranker_model = CrossEncoder(RERANKER_MODEL, device=device)

# Prompt interactively for the Gemini key only when it isn't already in the
# environment (e.g. supplied via .env loaded above).
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
llm = init_chat_model(LLM_MODEL, model_provider="google_genai")

print("Loading FAISS index...")
# allow_dangerous_deserialization is required by LangChain to unpickle a
# locally-saved index; safe only because the index dir is produced by us,
# never from untrusted input.
vector_store = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)

print("FAISS index loaded successfully")

class State(TypedDict):
    """Pipeline state schema: user question, retrieved docs, generated answer.

    NOTE(review): this TypedDict is not referenced anywhere in the visible
    code (``retrieve``/``rag_pipeline`` pass values directly) — confirm it is
    used by another module or consider removing it.
    """
    question: str
    context: List
    answer: str

def retrieve(question, faiss_broad_k=30, metadata_min_keep=8, top_k=5, semantic_threshold=0.35):
    """Return up to ``top_k`` documents relevant to ``question``.

    Three-stage retrieval:
      1. Broad FAISS similarity search (``faiss_broad_k`` candidates).
      2. Keyword pre-filter: keep docs whose title/concepts/body mention a
         tech/AI scope term; if none match, keep the first
         ``metadata_min_keep`` candidates — but only when the question itself
         is in scope, otherwise return ``[]``.
      3. Cross-encoder rerank; drop scores below ``semantic_threshold``.
    """
    candidates = vector_store.similarity_search_with_score(question, k=faiss_broad_k)
    print(f"[retrieve] FAISS returned {len(candidates)} candidates")

    scope_terms = [
        "machine learning", "artificial intelligence", "deep learning", "robotics",
        "data science", "neural network", "quantum computing", "automation",
        "computer vision", "nlp", "natural language", "algorithm", "software",
        "engineering", "big data", "reinforcement learning"
    ]

    def _in_scope(blob):
        return any(term in blob for term in scope_terms)

    scoped = [
        doc for doc, _ in candidates
        if _in_scope(
            f"{doc.metadata.get('title','')} {doc.metadata.get('concepts','')} {doc.page_content}".lower()
        )
    ]

    if not scoped:
        # Off-topic question: bail out rather than force irrelevant results.
        if not _in_scope(question.lower()):
            return []
        # In-scope question but no keyword hits — fall back to FAISS order.
        scoped = [doc for doc, _ in candidates[:metadata_min_keep]]

    pair_inputs = [[question, doc.page_content] for doc in scoped]
    rerank_scores = reranker_model.predict(pair_inputs)
    ranked = sorted(zip(scoped, rerank_scores), key=lambda item: item[1], reverse=True)

    kept = [doc for doc, score in ranked if score >= semantic_threshold]
    return kept[:top_k]

def generate_with_table(question: str, docs: List):
    """Generate an LLM answer for ``question`` grounded in ``docs``.

    Builds a Markdown (pipe) table of paper metadata, joins the document
    bodies as context, and asks the LLM to answer plus append the table.
    Returns a not-found message immediately when ``docs`` is empty.
    """
    if not docs:
        return f"I cannot find anything based on the search term **{question}**."

    headers = ["Title", "Year", "Authors", "Concepts"]
    rows = [
        [
            doc.metadata.get("title", ""),
            doc.metadata.get("pub_year", ""),
            doc.metadata.get("authors", ""),
            doc.metadata.get("concepts", ""),
        ]
        for doc in docs
    ]
    papers_table = tabulate(rows, headers=headers, tablefmt="pipe")

    context_text = "\n\n".join(doc.page_content for doc in docs)

    system_text = (
        "You are an expert RAG system. Answer the user's question based ONLY on the provided context. "
        "After your answer, append a Markdown table of the retrieved papers. "
        "If there are results: say how many you found for [Search term]. "
        "If none: say 'I cannot find anything based on the search term [Search term]'. "
        "Output format: [Summary] \n\n [Markdown Table]. Context: {context}"
    )
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_text),
        ("human", "Question: {question}. Table of papers: \n\n{papers_table}"),
    ])

    messages = prompt.invoke({
        "question": question,
        "context": context_text,
        "papers_table": papers_table,
    })

    return llm.invoke(messages).content

def rag_pipeline(question: str):
    """End-to-end entry point: retrieve documents, then generate the answer.

    Blank/whitespace-only input short-circuits with a user-facing prompt.
    """
    if not question.strip():
        return "Please enter a question."
    return generate_with_table(question, retrieve(question))

# Minimal Gradio UI: one textbox in, Markdown out (answer + paper table).
demo = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Enter your research query:", placeholder="e.g., deep learning in robotics"),
    outputs=gr.Markdown(label="Response"),
    title="📚 Research Paper RAG Assistant",
    description="Retrieves and summarizes papers related to your query using FAISS, CrossEncoder, and Gemini."
)

# Launch the web server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()