File size: 3,230 Bytes
ca767c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from dotenv import load_dotenv
import os
import uvicorn

load_dotenv()

app = FastAPI(title="A RAG-Driven Learning Assistant for Biology")

from langchain_community.document_loaders import DirectoryLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.schema import Document, BaseRetriever
from sentence_transformers import CrossEncoder
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq

loader = DirectoryLoader('data/', glob="**/*.pdf", show_progress=True, loader_cls=PyPDFLoader)
documents = loader.load()

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)
chunks = splitter.split_documents(documents)

embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
vectorstore = FAISS.from_documents(chunks, embeddings)

GROQ_API_KEY = os.getenv('GROQ_API_KEY')
if not GROQ_API_KEY:
    raise ValueError("GROQ_API_KEY is not set in the environment variables")
    
llm = ChatGroq(api_key=GROQ_API_KEY, model='llama-3.3-70b-versatile')

prompt = ChatPromptTemplate.from_messages([
    ("system", "You are a helpful and knowledgeable biology tutor. Answer clearly and accurately. If the query is out of syllabus, just respond with 'Out of syllabus'."), 
    ("human", "Context:\n{context}\n\nQuestion: {question}")
])

memory = ConversationBufferWindowMemory(
    memory_key="chat_history", 
    return_messages=True, 
    k=3
)

reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

def rerank_documents(query: str, retrieved_docs: List[Document]) -> List[Document]:
    docs_texts = [doc.page_content for doc in retrieved_docs]
    pairs = [(query, doc_text) for doc_text in docs_texts]
    scores = reranker.predict(pairs)
    sorted_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), key=lambda x: x[0], reverse=True)]
    return sorted_docs

class RerankRetriever(BaseRetriever, BaseModel):
    base_retriever: BaseRetriever
    top_k: int = 5

    def _get_relevant_documents(self, query: str) -> List[Document]:
        initial_docs = self.base_retriever.invoke(query)
        reranked_docs = rerank_documents(query, initial_docs)
        return reranked_docs[:self.top_k]

base_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
custom_retriever = RerankRetriever(base_retriever=base_retriever, top_k=5)

qa_chain = ConversationalRetrievalChain.from_llm(
    llm=llm,
    retriever=custom_retriever,
    memory=memory,
    combine_docs_chain_kwargs={"prompt": prompt}
)

class QuestionInput(BaseModel):
    question: str

@app.post("/predict")
def predict(input: QuestionInput):
    result = qa_chain({"question": input.question})  
    return {"answer": result["answer"]}

if __name__ == "__main__":
    uvicorn.run(app, host='0.0.0.0', port=2000)