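"""TechChat: a FastAPI RAG service for KNUST admissions questions.

Pipeline: PDF and JSONL sources are semantically chunked with sentence
embeddings + KMeans, indexed in FAISS and BM25, fused via an ensemble
retriever, and answered by Gemini through a prompt template.
"""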
import os
import nltk
import logging
import json
import numpy as np
from sklearn.cluster import KMeans
nltk.data.path.append("/app/nltk_data")  # pre-baked NLTK data (punkt is needed by sent_tokenize)

# Redirect Hugging Face caches and temp files to writable app paths (container-friendly)
os.environ["HF_HOME"] = "/app/cache"
os.environ["XDG_CACHE_HOME"] = "/app/cache"
os.environ["TMPDIR"] = "/app/tmp"

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
from langchain_core.output_parsers import StrOutputParser
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.documents import Document

app = FastAPI(title="TechChat RAG")

class Question(BaseModel):
    """Request body for /chat, e.g. {"query": "When is the application deadline?"}"""
    query: str

def semantic_chunk_with_embeddings(documents, embeddings, max_chunk_size=1000, min_sentences=2, overlap_sentences=1):
    """Chunk documents into semantically related groups using embeddings and clustering."""
    all_chunks = []
    for doc in documents:
        sentences = nltk.sent_tokenize(doc.page_content)
        if len(sentences) < min_sentences:
            all_chunks.append(Document(page_content=" ".join(sentences), metadata=doc.metadata))
            continue

        # Generate embeddings for each sentence
        sentence_embeddings = embeddings.embed_documents(sentences)
        sentence_embeddings = np.array(sentence_embeddings)

        # Cluster sentences using KMeans (dynamically determine num clusters)
        num_clusters = max(1, min(len(sentences) // min_sentences, 10))  # Cap at 10 clusters
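        # e.g. 30 sentences with min_sentences=2 gives min(30 // 2, 10) = 10 clusters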
        kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(sentence_embeddings)
        labels = kmeans.labels_

        # Group sentences by cluster
        clusters = {}
        for sentence, label in zip(sentences, labels):
            if label not in clusters:
                clusters[label] = []
            clusters[label].append(sentence)

        # Form chunks from clusters with overlap
        for cluster_id, cluster_sentences in clusters.items():
            current_chunk = ""
            chunk_sentences = []
            for sentence in cluster_sentences:
                if len(current_chunk) + len(sentence) < max_chunk_size:
                    current_chunk += sentence + " "
                    chunk_sentences.append(sentence)
                else:
                    # Emit the finished chunk; guard against an empty first chunk
                    # when a single sentence already exceeds max_chunk_size
                    if current_chunk.strip():
                        all_chunks.append(Document(page_content=current_chunk.strip(), metadata=doc.metadata))
                    # Carry the last `overlap_sentences` sentences into the next chunk
                    overlap = " ".join(chunk_sentences[-overlap_sentences:])
                    current_chunk = (overlap + " " if overlap else "") + sentence + " "
                    chunk_sentences = chunk_sentences[-overlap_sentences:] + [sentence]
            if current_chunk.strip():
                all_chunks.append(Document(page_content=current_chunk.strip(), metadata=doc.metadata))

    return all_chunks
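
# A minimal sketch of exercising the chunker on its own (hypothetical PDF path;
# reuses the same embedding model that load_rag_system builds below):
#
#   emb = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
#   docs = PyPDFLoader("some_document.pdf").load()
#   chunks = semantic_chunk_with_embeddings(docs, emb)
#   print(len(chunks), chunks[0].page_content[:80])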

def load_rag_system():
    logger.info("Loading Gemini model...")
    try:
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=0.3,
            top_p=0.9,
            max_tokens=1024
        )
    except Exception as e:
        logger.error(f"Gemini loading failed: {str(e)}")
        raise

    # Load embeddings for chunking and retrieval
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Load PDF
    logger.info("Loading PDF...")
    pdf_paths = ["Admission_Requirement.pdf", "POSTGRADUATE_ADMISSIONS.pdf"]
    pages = []
    for path in pdf_paths:
        if os.path.exists(path):
            loader = PyPDFLoader(path)
            pages.extend(loader.load())
        else:
            logger.warning(f"PDF not found: {path}")

    pdf_docs = semantic_chunk_with_embeddings(pages, embeddings)

    # Load JSONL
    logger.info("Loading JSONL data...")
    jsonl_paths = ["cleaned-dataset.jsonl"]
    jsonl_docs = []
    for path in jsonl_paths:
        if os.path.exists(path):
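            # Each line is a JSON record with "instruction" and "response" keys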
            with open(path, "r") as f:
                for line in f:
                    data = json.loads(line.strip())
                    content = f"Instruction: {data['instruction']}\nResponse: {data['response']}"
                    jsonl_docs.append(Document(page_content=content, metadata={"source": "jsonl", "instruction": data["instruction"]}))
        else:
            logger.warning(f"JSONL not found: {path}")

    # Combine documents
    all_docs = pdf_docs + jsonl_docs
    unique_docs = {doc.page_content: doc for doc in all_docs}.values()
    for i, doc in enumerate(unique_docs):
        doc.metadata["doc_id"] = i

    logger.info("Building vector store...")
    vectorstore = FAISS.from_documents(list(unique_docs), embedding=embeddings)
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})  # top-4 semantic matches
    bm25_retriever = BM25Retriever.from_documents(list(unique_docs))
    bm25_retriever.k = 4  # top-4 lexical matches
    retriever = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
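    # The ensemble fuses BM25 (lexical) and FAISS (semantic) rankings with
    # weighted Reciprocal Rank Fusion. A quick sanity check (hypothetical query):
    #
    #   docs = retriever.invoke("What are the undergraduate entry requirements?")
    #   for d in docs:
    #       print(d.metadata.get("doc_id"), d.page_content[:60])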

    template = """
    You are TechChat, an AI assistant created to provide accurate, concise, and helpful information about admissions to Kwame Nkrumah University of Science and Technology (KNUST). Your primary goal is to assist users with questions related to KNUST admissions, including application processes, requirements, deadlines, programs, and other relevant details.

    ### Instructions:
    1. **KNUST Admissions Questions**: Use the provided context as a guide to answer questions about KNUST admissions clearly and accurately. You are not required to follow the context verbatim if it is inaccurate; refine it into a correct, well-formed response. If the context is sufficient, tailor your response to the specific details provided.
    2. **Limited Context**: If the context lacks enough information to fully answer a KNUST admissions question, provide a general but accurate response based on your knowledge of KNUST admissions, and invite the user to provide more details for a more specific answer.
    3. **Off-Topic Questions**: If the question is unrelated to KNUST admissions, respond politely with: "I'm sorry, that question is outside my focus on KNUST admissions. Feel free to ask about KNUST application processes, requirements, or programs, and I'll be happy to help!"
    4. **Tone and Style**: Maintain a friendly, professional, and approachable tone. Avoid overly technical jargon unless necessary, and ensure responses are easy to understand.
    5. **No Assumptions**: Do not invent information. If you cannot answer due to missing or unclear information, acknowledge it and encourage the user to clarify.

    ### Context:
    {context}

    ### Question:
    {question}

    ### Answer:
    """
    
    prompt = PromptTemplate.from_template(template)
    parser = StrOutputParser()
    chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
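    # With an output parser attached, chain.invoke returns a dict whose "text"
    # key holds the parsed string, e.g. (placeholder inputs):
    #
    #   out = chain.invoke({"context": "...", "question": "When do applications open?"})
    #   print(out["text"])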

    return retriever, vectorstore, chain

logger.info("Initializing RAG system...")
retriever, vectorstore, chain = load_rag_system()



@app.post("/chat")
async def ask(question: Question):
    try:
        logger.info(f"Received question: {question.query}")
        context_docs = retriever.invoke(question.query)
        logger.info(f"Retrieved {len(context_docs)} context documents")
        # FAISS similarity_search_with_score returns an L2 distance by default
        # (lower = more similar), so gate on a maximum distance. A single top-1
        # search suffices; the previous per-document loop repeated the same query.
        best_distance = vectorstore.similarity_search_with_score(question.query, k=1)[0][1]
        if best_distance > 1.25:  # assumed cutoff for all-MiniLM-L6-v2; tune empirically
            logger.info("Best match too distant, returning fallback answer")
            return {"answer": "I'm not sure about that, but I'd be happy to help if you provide more details!"}
        context_text = "\n".join([doc.page_content for doc in context_docs])
        logger.info("Generating response...")
        response = chain.invoke({"context": context_text, "question": question.query})
        answer = response['text'].split("Answer:")[-1].strip()
        logger.info(f"Generated answer: {answer}")
        return {"answer": answer}
    except Exception as e:
        logger.error(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
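
# Example client call for /chat (assumes the server runs on localhost:8000;
# `requests` is not a dependency of this module):
#
#   import requests
#   r = requests.post("http://localhost:8000/chat",
#                     json={"query": "What documents do I need to apply?"})
#   print(r.json()["answer"])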

@app.get("/")
async def root():
    print('HR Policy Bot is actively running!')
    logger.info("Root endpoint accessed")
    return {"message": "HR Policy Bot is running!"}