tech-chat-rag / app.py
pfrimpong's picture
Update app.py
b73e3cb verified
import asyncio
import json
import logging
import os

import nltk
import numpy as np
from sklearn.cluster import KMeans
# Extra directory NLTK searches for data files (nltk.sent_tokenize is used by the
# chunker below); presumably provisioned by the deployment image — confirm.
nltk.data.path.append("/app/nltk_data")
# Point Hugging Face / generic caches and temp files under /app. These are set
# before the langchain / Hugging Face imports further down, presumably so those
# libraries pick the locations up (e.g. a writable dir in the container) — confirm.
os.environ["HF_HOME"] = "/app/cache"
os.environ["XDG_CACHE_HOME"] = "/app/cache"
os.environ["TMPDIR"] = "/app/tmp"
# Process-wide logging: INFO and above, with timestamp and level.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_core.output_parsers import StrOutputParser
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.documents import Document
# ASGI application; the /chat and / routes below are registered on it.
app = FastAPI(title="TechChat Rag")


class Question(BaseModel):
    """JSON request body for ``POST /chat``."""

    query: str  # the user's question; used both for retrieval and in the LLM prompt
def _pack_sentences(sentences, max_chunk_size, overlap_sentences):
    """Greedily join ``sentences`` into text chunks with sentence-level overlap.

    A sentence joins the current chunk while the chunk's running length (each
    sentence counted with one trailing space) plus the sentence stays under
    ``max_chunk_size``; otherwise the chunk is emitted and a new one starts with
    the last ``overlap_sentences`` sentences of the previous chunk. Every chunk
    takes at least one new sentence, so an oversized sentence yields an
    oversized chunk instead of an empty one.

    Returns:
        List of chunk strings, in input order.
    """
    chunks = []
    chunk_sentences = []
    chunk_len = 0  # sum of len(s) + 1 over chunk_sentences
    for sentence in sentences:
        if chunk_sentences and chunk_len + len(sentence) >= max_chunk_size:
            chunks.append(" ".join(chunk_sentences).strip())
            # A [-0:] slice would copy the whole list, so 0 must mean "no overlap".
            chunk_sentences = chunk_sentences[-overlap_sentences:] if overlap_sentences > 0 else []
            chunk_len = sum(len(s) + 1 for s in chunk_sentences)
        chunk_sentences.append(sentence)
        chunk_len += len(sentence) + 1
    if chunk_sentences:
        chunks.append(" ".join(chunk_sentences).strip())
    return chunks


def semantic_chunk_with_embeddings(documents, embeddings, max_chunk_size=1000, min_sentences=2, overlap_sentences=1):
    """Chunk documents into groups of semantically related sentences.

    Each document is split with ``nltk.sent_tokenize``; its sentences are embedded
    with ``embeddings.embed_documents`` and grouped with KMeans. The sentences of
    each cluster are then packed into chunks of under ``max_chunk_size``
    characters, repeating the last ``overlap_sentences`` sentences of a chunk at
    the start of the next. Because grouping is by cluster, a chunk's sentences
    are not necessarily contiguous in the source document.

    Args:
        documents: Iterable of ``Document`` objects to chunk.
        embeddings: Object exposing ``embed_documents(list[str]) -> list[list[float]]``.
        max_chunk_size: Soft character limit per chunk; a single sentence longer
            than this still becomes its own (oversized) chunk.
        min_sentences: Documents with fewer sentences are kept whole as one chunk.
            Also sizes the clustering: ``len(sentences) // min_sentences``
            clusters, at least 1 and at most 10.
        overlap_sentences: Trailing sentences carried into the next chunk;
            0 disables overlap.

    Returns:
        List of ``Document`` chunks. Each chunk gets its own copy of the source
        document's metadata. Documents with no sentences produce no chunks.
    """
    all_chunks = []
    for doc in documents:
        sentences = nltk.sent_tokenize(doc.page_content)
        if not sentences:
            # Nothing worth indexing (e.g. a blank PDF page).
            continue
        if len(sentences) < min_sentences:
            all_chunks.append(Document(page_content=" ".join(sentences), metadata=dict(doc.metadata)))
            continue

        sentence_embeddings = np.array(embeddings.embed_documents(sentences))
        # Roughly one cluster per `min_sentences` sentences, capped at 10; the
        # max(..., 1) guards against a non-positive min_sentences.
        num_clusters = max(1, min(len(sentences) // max(min_sentences, 1), 10))
        labels = KMeans(n_clusters=num_clusters, random_state=42).fit(sentence_embeddings).labels_

        # Group sentences by cluster label, keeping first-seen cluster order.
        clusters = {}
        for sentence, label in zip(sentences, labels):
            clusters.setdefault(label, []).append(sentence)

        for cluster_sentences in clusters.values():
            for text in _pack_sentences(cluster_sentences, max_chunk_size, overlap_sentences):
                # Copy metadata so chunks never share one mutable dict.
                all_chunks.append(Document(page_content=text, metadata=dict(doc.metadata)))
    return all_chunks
_PROMPT_TEMPLATE = """
You are TechChat, an AI assistant created to provide accurate, concise, and helpful information about admissions to Kwame Nkrumah University of Science and Technology (KNUST). Your primary goal is to assist users with questions related to KNUST admissions, including application processes, requirements, deadlines, programs, and other relevant details.
### Instructions:
1. **KNUST Admissions Questions**: Use the provided context as a guide to answer questions about KNUST admissions clearly and accurately but you are not mandated to stick to the context if it is inaccurate. You are to refine the context for response. If the context is sufficient, tailor your response to the specific details provided.
2. **Limited Context**: If the context lacks enough information to fully answer a KNUST admissions question, provide a general but accurate response based on your knowledge of KNUST admissions, and invite the user to provide more details for a more specific answer.
3. **Off-Topic Questions**: If the question is unrelated to KNUST admissions, respond politely with: "I'm sorry, that question is outside my focus on KNUST admissions. Feel free to ask about KNUST application processes, requirements, or programs, and I'll be happy to help!"
4. **Tone and Style**: Maintain a friendly, professional, and approachable tone. Avoid overly technical jargon unless necessary, and ensure responses are easy to understand.
5. **No Assumptions**: Do not invent information. If you cannot answer due to missing or unclear information, acknowledge it and encourage the user to clarify.
### Context:
{context}
### Question:
{question}
### Answer:
"""


def _load_llm():
    """Create the Gemini chat model (API key from GOOGLE_API_KEY); logs and re-raises on failure."""
    logger.info("Loading Gemini model...")
    try:
        return ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=0.3,
            top_p=0.9,
            max_tokens=1024,
        )
    except Exception as e:
        logger.exception(f"Gemini loading failed: {str(e)}")
        raise


def _load_pdf_pages(pdf_paths):
    """Load all pages of the PDFs that exist; missing files are logged and skipped."""
    logger.info("Loading PDF...")
    pages = []
    for path in pdf_paths:
        if not os.path.exists(path):
            logger.warning(f"PDF not found: {path}")
            continue
        pages.extend(PyPDFLoader(path).load())
    return pages


def _load_jsonl_docs(jsonl_paths):
    """Turn instruction/response JSONL records into Documents.

    Missing files, blank lines, and records that are not JSON objects with
    ``instruction`` and ``response`` keys are logged and skipped.
    """
    logger.info("Loading JSONL data...")
    docs = []
    for path in jsonl_paths:
        if not os.path.exists(path):
            logger.warning(f"JSONL not found: {path}")
            continue
        with open(path, "r", encoding="utf-8") as f:
            for line_no, raw_line in enumerate(f, start=1):
                line = raw_line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    instruction = data["instruction"]
                    response = data["response"]
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    logger.warning(f"Skipping malformed JSONL record {path}:{line_no}: {e}")
                    continue
                content = f"Instruction: {instruction}\nResponse: {response}"
                docs.append(Document(page_content=content, metadata={"source": "jsonl", "instruction": instruction}))
    return docs


def _index_documents(docs):
    """De-duplicate by page_content and give each survivor a sequential ``doc_id``.

    On duplicate content the last document wins. Metadata is copied, never
    mutated, so inputs that share a metadata dict still get distinct ids.
    """
    unique_docs = {doc.page_content: doc for doc in docs}.values()
    return [
        Document(page_content=doc.page_content, metadata={**doc.metadata, "doc_id": i})
        for i, doc in enumerate(unique_docs)
    ]


def _build_retriever(docs, embeddings):
    """Build the FAISS store plus a 50/50 FAISS + BM25 ensemble retriever (k=4 each)."""
    logger.info("Building vector store...")
    vectorstore = FAISS.from_documents(docs, embedding=embeddings)
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    bm25_retriever = BM25Retriever.from_documents(docs)
    bm25_retriever.k = 4
    retriever = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
    return retriever, vectorstore


def _build_chain(llm):
    """Wire the admissions prompt, the LLM and a string output parser into an LLMChain."""
    prompt = PromptTemplate.from_template(_PROMPT_TEMPLATE)
    return LLMChain(llm=llm, prompt=prompt, output_parser=StrOutputParser())


def load_rag_system(
    pdf_paths=("Admission_Requirement.pdf", "POSTGRADUATE_ADMISSIONS.pdf"),
    jsonl_paths=("cleaned-dataset.jsonl",),
):
    """Load the LLM and corpus, index the corpus, and build the answering chain.

    Args:
        pdf_paths: PDF files to semantically chunk and index (missing ones skipped).
        jsonl_paths: Instruction/response JSONL files to index (missing ones skipped).

    Returns:
        ``(retriever, vectorstore, chain)``: the hybrid FAISS + BM25 retriever,
        the underlying FAISS store, and the prompt → LLM → string LLMChain.

    Raises:
        RuntimeError: If no documents could be loaded from any source.
        Exception: Whatever the Gemini client raises while being constructed.
    """
    llm = _load_llm()
    # One embedding model, shared by semantic chunking and the FAISS index.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    pdf_docs = semantic_chunk_with_embeddings(_load_pdf_pages(pdf_paths), embeddings)
    jsonl_docs = _load_jsonl_docs(jsonl_paths)

    indexed_docs = _index_documents(pdf_docs + jsonl_docs)
    if not indexed_docs:
        # FAISS cannot be built from an empty corpus; fail with a clear message.
        raise RuntimeError("No documents were loaded from the PDF/JSONL sources; cannot build the retriever.")

    retriever, vectorstore = _build_retriever(indexed_docs, embeddings)
    return retriever, vectorstore, _build_chain(llm)
# Build the RAG components once, at import time; the /chat handler reads these
# module-level names. An exception raised here aborts the module import.
logger.info("Initializing RAG system...")
retriever, vectorstore, chain = load_rag_system()
@app.post("/chat")
async def ask(question: Question):
    """Answer an admissions question with retrieval-augmented generation.

    Retrieves context with the hybrid retriever. If nothing is retrieved, or the
    closest FAISS match is below a relevance threshold, it returns a polite
    fallback. Otherwise it runs the LLM chain over the retrieved context.

    Args:
        question: Request body carrying the user's ``query``.

    Returns:
        ``{"answer": str}``.

    Raises:
        HTTPException: Status 500 (detail = error message) if retrieval or
            generation raises.
    """
    # Minimum LangChain relevance score (higher = more similar) that the closest
    # FAISS match must reach. NOTE(review): re-tune against real queries.
    min_relevance = 0.25
    logger.debug("ask route reached")
    try:
        logger.info(f"Received question: {question.query}")
        # The retriever, vector store and chain are synchronous; run them in a
        # worker thread so embedding/LLM latency does not block the event loop.
        context_docs = await asyncio.to_thread(retriever.invoke, question.query)
        logger.info(f"Retrieved {len(context_docs)} context documents")
        # similarity_search_with_score yields a distance (lower = closer), so gate
        # on the normalised relevance score instead; one search is enough.
        top_hits = await asyncio.to_thread(
            vectorstore.similarity_search_with_relevance_scores, question.query, k=1
        )
        best_relevance = top_hits[0][1] if top_hits else 0.0
        if not context_docs or best_relevance < min_relevance:
            logger.info("Similarity too low, returning 'I don’t know'")
            return {"answer": "I'm not sure about that, but I'd be happy to help if you provide more details!"}
        context_text = "\n".join(doc.page_content for doc in context_docs)
        logger.info("Generating response...")
        response = await asyncio.to_thread(
            chain.invoke, {"context": context_text, "question": question.query}
        )
        # Keep only what follows the last "Answer:", in case the model echoes the heading.
        answer = response["text"].split("Answer:")[-1].strip()
        logger.info(f"Generated answer: {answer}")
        return {"answer": answer}
    except Exception as e:
        logger.exception(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.get("/")
async def root():
    """Liveness endpoint: log the hit and report that the service is running.

    Returns:
        ``{"message": str}``.
    """
    # Logging already records each hit; the old stdout print duplicated it.
    logger.info("Root endpoint accessed")
    return {"message": "TechChat Rag is running!"}