Spaces:
Runtime error
Runtime error
| import os | |
| import nltk | |
| import logging | |
| import json | |
| import numpy as np | |
| from sklearn.cluster import KMeans | |
# Add /app/nltk_data to the directories NLTK searches for its data files.
nltk.data.path.append("/app/nltk_data")
# Redirect the Hugging Face home, the generic XDG cache and the temp directory
# into /app. These assignments run before the langchain / Hugging Face imports
# further down in this file.
os.environ.update(
    {
        "HF_HOME": "/app/cache",
        "XDG_CACHE_HOME": "/app/cache",
        "TMPDIR": "/app/tmp",
    }
)
# Root logging at INFO with a timestamped format; this module logs via `logger`.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain.prompts import PromptTemplate | |
| from langchain.chains import LLMChain | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_community.retrievers import BM25Retriever | |
| from langchain.retrievers import EnsembleRetriever | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from langchain_core.documents import Document | |
# Request body accepted by the question-answering handler.
class Question(BaseModel):
    # The user's free-text question.
    query: str


# The FastAPI application object for this service.
app = FastAPI(title="TechChat Rag")
def _pack_sentences(sentences, max_chunk_size, overlap_sentences):
    """Greedily pack ``sentences`` (kept in order) into chunk strings.

    A sentence joins the current chunk while the chunk text (with a trailing
    separator) plus the sentence stays shorter than ``max_chunk_size``
    characters. Otherwise the current chunk is emitted, and the next chunk
    starts with its last ``overlap_sentences`` sentences. There is no carry-over
    when ``overlap_sentences <= 0``. A sentence that is too long on its own
    still lands in a chunk; it never produces an empty chunk.
    """
    chunks = []
    parts = []
    # Always equals len(" ".join(parts) + " "): chunk text plus trailing separator.
    size = 0
    for sentence in sentences:
        if parts and size + len(sentence) >= max_chunk_size:
            chunks.append(" ".join(parts).strip())
            # parts[-0:] would be the *whole* list, so an overlap of 0 needs its own branch.
            parts = parts[-overlap_sentences:] if overlap_sentences > 0 else []
            size = sum(len(p) + 1 for p in parts)
        parts.append(sentence)
        size += len(sentence) + 1
    if parts:
        chunks.append(" ".join(parts).strip())
    return chunks


def semantic_chunk_with_embeddings(documents, embeddings, max_chunk_size=1000, min_sentences=2, overlap_sentences=1, max_clusters=10):
    """Chunk documents into semantically related groups using embeddings and clustering.

    Each document's ``page_content`` is split into sentences with
    ``nltk.sent_tokenize``:

    * documents with no sentences (blank pages) are skipped;
    * documents with fewer than ``min_sentences`` sentences become one chunk;
    * otherwise every sentence is embedded and grouped with KMeans.

    The number of clusters is ``len(sentences) // min_sentences``, clamped to
    ``[1, max_clusters]``. Each cluster is then packed into chunks of roughly
    ``max_chunk_size`` characters. Consecutive chunks of the same cluster share
    ``overlap_sentences`` sentences. A cluster may gather non-adjacent
    sentences; within a cluster the original sentence order is kept.

    Args:
        documents: iterable of objects with ``page_content`` and ``metadata``.
        embeddings: object whose ``embed_documents(list[str])`` returns one
            vector per sentence.
        max_chunk_size: approximate character limit per chunk (a single
            longer sentence is kept whole; overlap can push a chunk past it).
        min_sentences: minimum sentences for clustering; also the target
            number of sentences per cluster.
        overlap_sentences: sentences repeated at the start of the next chunk;
            ``0`` disables overlap.
        max_clusters: upper bound on KMeans clusters per document.

    Returns:
        A list of ``Document`` chunks. Each chunk carries its own copy of the
        source document's metadata.
    """
    all_chunks = []
    for doc in documents:
        sentences = nltk.sent_tokenize(doc.page_content)
        if not sentences:
            continue  # nothing to index; an empty chunk would only add noise
        if len(sentences) < min_sentences:
            # dict(...) so later per-chunk metadata edits cannot leak across chunks.
            all_chunks.append(Document(page_content=" ".join(sentences), metadata=dict(doc.metadata)))
            continue
        # Generate embeddings for each sentence.
        sentence_embeddings = np.array(embeddings.embed_documents(sentences))
        # max(min_sentences, 1) guards the division when min_sentences <= 0.
        num_clusters = max(1, min(len(sentences) // max(min_sentences, 1), max_clusters))
        labels = KMeans(n_clusters=num_clusters, random_state=42).fit(sentence_embeddings).labels_
        # Group sentences by cluster label, preserving sentence order inside each group.
        clusters = {}
        for sentence, label in zip(sentences, labels):
            clusters.setdefault(label, []).append(sentence)
        for cluster_sentences in clusters.values():
            for text in _pack_sentences(cluster_sentences, max_chunk_size, overlap_sentences):
                all_chunks.append(Document(page_content=text, metadata=dict(doc.metadata)))
    return all_chunks
def _load_pdf_pages(pdf_paths):
    """Load every page of each PDF in ``pdf_paths`` that exists; warn about missing files."""
    pages = []
    for path in pdf_paths:
        if not os.path.exists(path):
            logger.warning(f"PDF not found: {path}")
            continue
        pages.extend(PyPDFLoader(path).load())
    return pages


def _load_jsonl_docs(jsonl_paths):
    """Turn each JSONL record with ``instruction``/``response`` keys into a Document.

    Missing files are logged and skipped. Blank lines are ignored. Malformed
    JSON or a record missing either key still raises.
    """
    docs = []
    for path in jsonl_paths:
        if not os.path.exists(path):
            logger.warning(f"JSONL not found: {path}")
            continue
        with open(path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:
                    continue  # json.loads("") would raise on a blank/trailing line
                data = json.loads(line)
                content = f"Instruction: {data['instruction']}\nResponse: {data['response']}"
                docs.append(Document(page_content=content, metadata={"source": "jsonl", "instruction": data["instruction"]}))
    return docs


def _dedupe_and_tag(docs):
    """Drop documents with duplicate ``page_content`` and number the survivors.

    For duplicates, the last document's metadata wins, at the position of the
    first occurrence. Each survivor is rebuilt with a fresh metadata dict that
    carries a sequential ``doc_id``. A dict shared between several source
    documents therefore cannot end up with one overwritten id.
    """
    by_content = {doc.page_content: doc for doc in docs}
    return [
        Document(page_content=doc.page_content, metadata={**doc.metadata, "doc_id": i})
        for i, doc in enumerate(by_content.values())
    ]


def load_rag_system():
    """Build the retrieval-augmented generation pipeline.

    The pipeline is assembled in these steps:

    1. Create the Gemini chat model and the MiniLM sentence embeddings.
    2. Semantically chunk the admission PDFs and load the JSONL
       instruction/response pairs.
    3. De-duplicate the combined documents by text.
    4. Index them in a FAISS vector store and a BM25 index, fused 50/50 by an
       ``EnsembleRetriever``.

    Returns:
        ``(retriever, vectorstore, chain)``:

        * the ensemble retriever (k=4 per sub-retriever);
        * the FAISS store itself;
        * an ``LLMChain`` whose prompt takes ``context`` and ``question``
          and whose output is parsed to a string.

    Raises:
        RuntimeError: if no source file produced any document.
        Any exception from constructing the Gemini client is logged and re-raised.
    """
    logger.info("Loading Gemini model...")
    try:
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=0.3,
            top_p=0.9,
            max_tokens=1024
        )
    except Exception as e:
        logger.error(f"Gemini loading failed: {str(e)}")
        raise
    # Embeddings are used both for semantic chunking and for the FAISS index.
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    logger.info("Loading PDF...")
    pages = _load_pdf_pages(["Admission_Requirement.pdf", "POSTGRADUATE_ADMISSIONS.pdf"])
    pdf_docs = semantic_chunk_with_embeddings(pages, embeddings)

    logger.info("Loading JSONL data...")
    jsonl_docs = _load_jsonl_docs(["cleaned-dataset.jsonl"])

    unique_docs = _dedupe_and_tag(pdf_docs + jsonl_docs)
    if not unique_docs:
        # FAISS/BM25 cannot be built from an empty corpus; fail with a readable reason.
        raise RuntimeError("No documents were loaded; check that the PDF/JSONL source files are present.")

    logger.info("Building vector store...")
    vectorstore = FAISS.from_documents(unique_docs, embedding=embeddings)
    faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    bm25_retriever = BM25Retriever.from_documents(unique_docs)
    bm25_retriever.k = 4
    # Hybrid retrieval: dense (FAISS) and lexical (BM25) results weighted equally.
    retriever = EnsembleRetriever(retrievers=[faiss_retriever, bm25_retriever], weights=[0.5, 0.5])
    template = """
You are TechChat, an AI assistant created to provide accurate, concise, and helpful information about admissions to Kwame Nkrumah University of Science and Technology (KNUST). Your primary goal is to assist users with questions related to KNUST admissions, including application processes, requirements, deadlines, programs, and other relevant details.
### Instructions:
1. **KNUST Admissions Questions**: Use the provided context as a guide to answer questions about KNUST admissions clearly and accurately but you are not mandated to stick to the context if it is inacurate. You are to refine the context for response. If the context is sufficient, tailor your response to the specific details provided.
2. **Limited Context**: If the context lacks enough information to fully answer a KNUST admissions question, provide a general but accurate response based on your knowledge of KNUST admissions, and invite the user to provide more details for a more specific answer.
3. **Off-Topic Questions**: If the question is unrelated to KNUST admissions, respond politely with: "I'm sorry, that question is outside my focus on KNUST admissions. Feel free to ask about KNUST application processes, requirements, or programs, and I'll be happy to help!"
4. **Tone and Style**: Maintain a friendly, professional, and approachable tone. Avoid overly technical jargon unless necessary, and ensure responses are easy to understand.
5. **No Assumptions**: Do not invent information. If you cannot answer due to missing or unclear information, acknowledge it and encourage the user to clarify.
### Context:
{context}
### Question:
{question}
### Answer:
"""
    prompt = PromptTemplate.from_template(template)
    parser = StrOutputParser()
    chain = LLMChain(llm=llm, prompt=prompt, output_parser=parser)
    return retriever, vectorstore, chain


# Build the shared RAG components once, when the module is imported; the request
# handlers read these module-level names. Any failure here aborts the import.
logger.info("Initializing RAG system...")
retriever, vectorstore, chain = load_rag_system()
async def ask(question: Question):
    """Answer a KNUST-admissions question with retrieval-augmented generation.

    NOTE(review): no route decorator is visible on this function here, so
    FastAPI will not serve it unless it is registered elsewhere. Confirm.
    NOTE(review): the retriever/vector-store/LLM calls are synchronous and run
    on the event loop inside this ``async def``. Consider ``ainvoke`` or a plain
    ``def`` handler.

    Args:
        question: request body; ``question.query`` is the user's question.

    Returns:
        ``{"answer": <str>}``. If nothing is retrieved, or the best vector
        match falls below the relevance threshold, a fixed fallback answer is
        returned without calling the LLM.

    Raises:
        HTTPException: status 500 wrapping any failure during retrieval or
            generation (the exception text is sent as ``detail``).
    """
    print('ask route reached')
    try:
        logger.info(f"Received question: {question.query}")
        context_docs = retriever.invoke(question.query)
        logger.info(f"Retrieved {len(context_docs)} context documents")
        # FAISS.similarity_search_with_score returns an L2 *distance* (lower =
        # closer), so comparing it with "< threshold" refused the best matches.
        # Relevance scores are normalised so that higher means more similar. The
        # lookup does not depend on the retrieved docs, so it runs once.
        # (Scores can fall slightly outside [0, 1] for un-normalised embeddings.)
        max_similarity = 0
        if context_docs:
            scored = vectorstore.similarity_search_with_relevance_scores(question.query, k=1)
            max_similarity = max((score for _, score in scored), default=0)
        if max_similarity < 0.25:  # minimum relevance required to attempt an answer
            logger.info("Similarity too low, returning 'I don’t know'")
            return {"answer": "I'm not sure about that, but I'd be happy to help if you provide more details!"}
        context_text = "\n".join(doc.page_content for doc in context_docs)
        logger.info("Generating response...")
        response = chain.invoke({"context": context_text, "question": question.query})
        # Keep only the text after the last "Answer:" in case the model echoes the prompt's heading.
        answer = response['text'].split("Answer:")[-1].strip()
        logger.info(f"Generated answer: {answer}")
        return {"answer": answer}
    except Exception as e:
        logger.exception(f"Error processing question: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def root():
    """Health-check handler: report that the service process is up.

    Prints a banner to stdout, emits an INFO log record, and returns a static
    ``{"message": ...}`` payload.

    NOTE(review): the messages mention an "HR Policy Bot" while the app is
    titled "TechChat Rag". This looks like a leftover name; confirm before changing.
    """
    banner = 'HR Policy Bot is actively running!'
    print(banner)
    logger.info("Root endpoint accessed")
    payload = {"message": "HR Policy Bot is running!"}
    return payload