BioRAG / main.py
PranavReddy18's picture
Upload 8 files
ca767c0 verified
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from dotenv import load_dotenv
import os
import uvicorn
load_dotenv()
app = FastAPI(title="A RAG-Driven Learning Assistant for Biology")
from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document, BaseRetriever
from sentence_transformers import CrossEncoder
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
loader = DirectoryLoader('data/', glob="**/*.pdf", show_progress=True, loader_cls=PyPDFLoader)
documents = loader.load()
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = splitter.split_documents(documents)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(chunks, embeddings)
GROQ_API_KEY = os.getenv('GROQ_API_KEY')
if not GROQ_API_KEY:
raise ValueError("GROQ_API_KEY is not set in the environment variables")
llm = ChatGroq(api_key=GROQ_API_KEY, model='llama-3.3-70b-versatile')
prompt = ChatPromptTemplate.from_messages([
("system", "You are a helpful and knowledgeable biology tutor. Answer clearly and accurately. If the query is out of syllabus, just respond with 'Out of syllabus'."),
("human", "Context:\n{context}\n\nQuestion: {question}")
])
memory = ConversationBufferWindowMemory(
memory_key="chat_history",
return_messages=True,
k=3
)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
def rerank_documents(query: str, retrieved_docs: List[Document]) -> List[Document]:
docs_texts = [doc.page_content for doc in retrieved_docs]
pairs = [(query, doc_text) for doc_text in docs_texts]
scores = reranker.predict(pairs)
sorted_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), key=lambda x: x[0], reverse=True)]
return sorted_docs
class RerankRetriever(BaseRetriever, BaseModel):
base_retriever: BaseRetriever
top_k: int = 5
def _get_relevant_documents(self, query: str) -> List[Document]:
initial_docs = self.base_retriever.invoke(query)
reranked_docs = rerank_documents(query, initial_docs)
return reranked_docs[:self.top_k]
base_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
custom_retriever = RerankRetriever(base_retriever=base_retriever, top_k=5)
qa_chain = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=custom_retriever,
memory=memory,
combine_docs_chain_kwargs={"prompt": prompt}
)
class QuestionInput(BaseModel):
question: str
@app.post("/predict")
def predict(input: QuestionInput):
result = qa_chain({"question": input.question})
return {"answer": result["answer"]}
if __name__ == "__main__":
uvicorn.run(app, host='0.0.0.0', port=2000)