# nlp-assignment / main.py
# Author: xoa-the-noob — commit "a new beginning" (5aa0be0)
# (Repository page header residue converted to comments so the file parses.)
import os
import getpass
import torch
from dotenv import load_dotenv
import gradio as gr
import faiss
from typing import List, TypedDict
from sentence_transformers import CrossEncoder
from langchain_community.vectorstores import FAISS
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_core.prompts import ChatPromptTemplate
from tabulate import tabulate
from langchain.chat_models import init_chat_model
from langchain_huggingface import HuggingFaceEmbeddings
# Load environment variables (e.g. GOOGLE_API_KEY) from a local .env file.
load_dotenv()

# Resolve the index directory relative to this file so the app works
# regardless of the current working directory.
PROJECT_PATH = os.path.dirname(os.path.abspath(__file__))
FAISS_INDEX_DIR = os.path.join(PROJECT_PATH, "faiss_index")
os.makedirs(FAISS_INDEX_DIR, exist_ok=True)

# Model identifiers: bi-encoder for embeddings, Gemini for generation,
# cross-encoder for reranking retrieved candidates.
EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
LLM_MODEL = 'gemini-2.5-flash'
RERANKER_MODEL = "BAAI/bge-reranker-base"

# Run the cross-encoder on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
reranker_model = CrossEncoder(RERANKER_MODEL, device=device)

# Prompt interactively for the Gemini key only if .env did not provide it.
if not os.environ.get("GOOGLE_API_KEY"):
    os.environ["GOOGLE_API_KEY"] = getpass.getpass("Enter API key for Google Gemini: ")

embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
llm = init_chat_model(LLM_MODEL, model_provider="google_genai")

print("Loading FAISS index...")
# allow_dangerous_deserialization=True is needed because the FAISS docstore is
# pickled; acceptable only because the index directory is produced locally.
vector_store = FAISS.load_local(
    FAISS_INDEX_DIR,
    embeddings=embeddings,
    allow_dangerous_deserialization=True,
)
print("FAISS index loaded successfully")
class State(TypedDict):
    """Shape of a RAG pipeline state: question in, retrieved context, answer out.

    NOTE(review): not referenced by the functions below — retrieve() and
    generate_with_table() take plain arguments. Looks like a leftover from a
    LangGraph-style state machine; confirm before removing.
    """
    question: str
    context: List
    answer: str
# Keywords defining the assignment's topical scope: a candidate document or
# the question itself must mention one of these to be considered in-domain.
# Hoisted to module level so the list is not rebuilt on every query.
SCOPE_B_KEYWORDS = (
    "machine learning", "artificial intelligence", "deep learning", "robotics",
    "data science", "neural network", "quantum computing", "automation",
    "computer vision", "nlp", "natural language", "algorithm", "software",
    "engineering", "big data", "reinforcement learning",
)


def retrieve(question, faiss_broad_k=30, metadata_min_keep=8, top_k=5, semantic_threshold=0.35):
    """Retrieve the best-matching documents for ``question`` in three stages.

    1. Broad FAISS similarity search (``faiss_broad_k`` candidates).
    2. Keyword scope filter over title/concepts metadata and body text.
       If nothing matches but the question itself is in scope, fall back
       to the top ``metadata_min_keep`` FAISS candidates.
    3. Cross-encoder rerank of (question, passage) pairs; keep documents
       scoring at least ``semantic_threshold``, truncated to ``top_k``.

    Returns a list of Documents — possibly empty when the question is out
    of scope or nothing survives the rerank threshold.
    """
    docs_with_scores = vector_store.similarity_search_with_score(question, k=faiss_broad_k)
    print(f"[retrieve] FAISS returned {len(docs_with_scores)} candidates")

    # Stage 2: keep only candidates that mention an in-scope keyword anywhere
    # in their title, concepts metadata, or page content.
    meta_filtered = []
    for doc, _ in docs_with_scores:
        text = f"{doc.metadata.get('title','')} {doc.metadata.get('concepts','')} {doc.page_content}".lower()
        if any(k in text for k in SCOPE_B_KEYWORDS):
            meta_filtered.append(doc)

    if not meta_filtered:
        # No candidate mentions a scope keyword. If the question is also out
        # of scope, give up; otherwise trust FAISS and keep the top hits.
        qlow = question.lower()
        if not any(k in qlow for k in SCOPE_B_KEYWORDS):
            return []
        meta_filtered = [doc for doc, _ in docs_with_scores[:metadata_min_keep]]

    # Stage 3: cross-encoder rerank; higher score means a better match.
    rerank_inputs = [[question, doc.page_content] for doc in meta_filtered]
    scores = reranker_model.predict(rerank_inputs)
    reranked = sorted(zip(meta_filtered, scores), key=lambda x: x[1], reverse=True)

    # Drop weak matches, then truncate to the requested number of documents.
    filtered = [(doc, s) for doc, s in reranked if s >= semantic_threshold]
    return [doc for doc, _ in filtered[:top_k]]
# The prompt is constant across requests, so build the template once at
# import time instead of on every call.
_RAG_PROMPT = ChatPromptTemplate.from_messages([
    (
        "system",
        "You are an expert RAG system. Answer the user's question based ONLY on the provided context. "
        "After your answer, append a Markdown table of the retrieved papers. "
        "If there are results: say how many you found for [Search term]. "
        "If none: say 'I cannot find anything based on the search term [Search term]'. "
        "Output format: [Summary] \n\n [Markdown Table]. Context: {context}"
    ),
    ("human", "Question: {question}. Table of papers: \n\n{papers_table}"),
])


def generate_with_table(question: str, docs: List):
    """Generate an LLM answer for ``question`` grounded in ``docs``.

    Builds a Markdown (pipe-format) table of the retrieved papers' metadata,
    prompts the LLM with the documents' text as the only context, and returns
    the model's answer text. Returns a fixed "not found" message when ``docs``
    is empty, without calling the LLM.
    """
    if not docs:
        return f"I cannot find anything based on the search term **{question}**."

    # Metadata table given to the model (and echoed back in its answer).
    rows = []
    for d in docs:
        m = d.metadata
        rows.append([
            m.get("title", ""),
            m.get("pub_year", ""),
            m.get("authors", ""),
            m.get("concepts", ""),
        ])
    headers = ["Title", "Year", "Authors", "Concepts"]
    papers_table = tabulate(rows, headers=headers, tablefmt="pipe")

    # Concatenated document bodies serve as the grounding context.
    docs_content = "\n\n".join(doc.page_content for doc in docs)

    messages = _RAG_PROMPT.invoke({
        "question": question,
        "context": docs_content,
        "papers_table": papers_table,
    })
    response = llm.invoke(messages)
    return response.content
def rag_pipeline(question: str):
    """End-to-end RAG entry point: validate the query, retrieve, generate."""
    if not question.strip():
        return "Please enter a question."
    relevant_docs = retrieve(question)
    return generate_with_table(question, relevant_docs)
# Gradio UI wiring: single textbox in, Markdown-rendered answer out.
demo = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(label="Enter your research query:", placeholder="e.g., deep learning in robotics"),
    outputs=gr.Markdown(label="Response"),
    title="📚 Research Paper RAG Assistant",
    description="Retrieves and summarizes papers related to your query using FAISS, CrossEncoder, and Gemini."
)

# Launch the web app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()