from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import re
import pandas as pd
from typing import List
from deep_translator import GoogleTranslator
from langchain_core.documents import Document
from langchain_community.document_loaders import (
    WebBaseLoader,
    PyPDFLoader,
    Docx2txtLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline
from huggingface_hub import snapshot_download

# ============================================================
# APP INIT
# ============================================================
app = Flask(__name__)
CORS(app)

# ============================================================
# CONSTANTS
# ============================================================
DATASET_REPO = "bshk57/Sastra_data"
LOCAL_DATASET_DIR = "knowledge_base"
VECTOR_DB_PATH = "sastra_vector_db"

EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
LLM_MODEL = "google/flan-t5-large"

MANDATORY_URLS = [
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
]

SASTRA_URLS = [
    "https://www.sastra.edu/about-us.html",
    "https://www.sastra.edu/academics/schools.html#school-of-computing",
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
    "https://www.sastra.edu/admissions/hostel-fees.html",
    "https://www.sastra.edu/infrastructure/physical-facilities.html",
    "https://www.sastra.edu/about-us/mission-vision.html",
]

os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)

# ============================================================
# UTILITIES
# ============================================================
def extract_urls(text: str) -> List[str]:
    """Return all http(s) URLs found in the given text."""
    return re.findall(r'https?://[^\s]+', text)


def clean_llm_output(text: str) -> str:
    """Strip boilerplate prefixes, collapse whitespace, and cap length."""
    text = re.sub(r'^(Answer:|Response:)', '', text, flags=re.I).strip()
    if text.lower().startswith("insufficient_data"):
        return ""
    return re.sub(r'\s+', ' ', text)[:700]

# ============================================================
# LOAD DATASET FROM HUGGINGFACE
# ============================================================
print("⬇ Downloading dataset...")
snapshot_download(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    local_dir=LOCAL_DATASET_DIR,
    local_dir_use_symlinks=False,
    ignore_patterns=[".gitattributes"]
)

# ============================================================
# LOAD LOCAL DOCUMENTS
# ============================================================
def load_local_documents():
    """Load PDF, DOCX, and XLSX files from LOCAL_DATASET_DIR as Documents."""
    docs = []
    for file in os.listdir(LOCAL_DATASET_DIR):
        path = os.path.join(LOCAL_DATASET_DIR, file)
        try:
            if file.lower().endswith(".pdf"):
                docs.extend(PyPDFLoader(path).load())
            elif file.lower().endswith(".docx"):
                docs.extend(Docx2txtLoader(path).load())
            elif file.lower().endswith(".xlsx"):
                # Flatten each spreadsheet row into a "col: value" string.
                df = pd.read_excel(path)
                for _, row in df.iterrows():
                    text = " | ".join(
                        f"{col}: {row[col]}"
                        for col in df.columns if pd.notna(row[col])
                    )
                    docs.append(Document(
                        page_content=text,
                        metadata={"source": file}
                    ))
        except Exception as e:
            print(f"⚠ Error loading {file}: {e}")
    return docs
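# Illustrative sanity check (not part of the pipeline): after the
# snapshot download above has populated LOCAL_DATASET_DIR, the loader
# can be exercised on its own before building the vector store, e.g.
#
#   docs = load_local_documents()
#   print(f"Loaded {len(docs)} documents")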
# ============================================================
# INITIALIZE RAG MODEL
# ============================================================
def initialize_model():
    """Build the vector store and RetrievalQA chain from all sources."""
    global vectordb, qa_chain

    documents = []

    # Website data
    for url in SASTRA_URLS:
        try:
            documents.extend(WebBaseLoader(url).load())
        except Exception as e:
            print(f"⚠ Error loading {url}: {e}")

    # Dataset + uploaded docs
    documents.extend(load_local_documents())

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents(documents)

    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL
    )

    vectordb = Chroma.from_documents(
        chunks,
        embeddings,
        persist_directory=VECTOR_DB_PATH
    )
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})

    generator = pipeline(
        "text2text-generation",
        model=LLM_MODEL,
        tokenizer=LLM_MODEL,
        max_new_tokens=200,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.2
    )
    llm = HuggingFacePipeline(pipeline=generator)

    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are AskSASTRA, the official SASTRA University assistant.
Answer strictly from the given context.
If information is missing, say INSUFFICIENT_DATA.

Context:
{context}

Question:
{question}

Answer:
"""
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=False
    )
    print("✅ AskSASTRA initialized")


initialize_model()

# ============================================================
# ROUTES
# ============================================================
@app.route("/", methods=["GET"])
def home():
    return jsonify({"message": "AskSASTRA API running 🚀"})


@app.route("/chat", methods=["POST"])
def chat():
    data = request.get_json(force=True)
    query = data.get("query", "").strip()
    lang = data.get("language", "en")

    if not query:
        return jsonify({"answer": "Please ask a valid question."})

    # Translate input → English
    if lang != "en":
        try:
            query_en = GoogleTranslator(source=lang, target="en").translate(query)
        except Exception:
            query_en = query
    else:
        query_en = query

    # RAG inference
    try:
        result = qa_chain.invoke({"query": query_en})
        answer_en = clean_llm_output(result.get("result", ""))
    except Exception:
        answer_en = ""

    if not answer_en:
        answer_en = (
            "I could not find confident information for this question. "
            "Please visit the official SASTRA website."
        )

    # URL handling
    extracted_urls = extract_urls(answer_en)
    all_urls = list(set(extracted_urls + MANDATORY_URLS))

    # Translate output back
    if lang != "en":
        try:
            final_answer = GoogleTranslator(source="en", target=lang).translate(answer_en)
        except Exception:
            final_answer = answer_en
    else:
        final_answer = answer_en

    return jsonify({
        "answer": final_answer,
        "urls": all_urls
    })


@app.route("/retrain", methods=["POST"])
def retrain():
    data = request.get_json(force=True)
    context = data.get("context", "").strip()

    if not context:
        return jsonify({"error": "Context text required"}), 400

    doc = Document(
        page_content=context,
        metadata={"source": "user_retrain"}
    )
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents([doc])

    # The live retriever reads from this same vectordb, so the new
    # chunks are searchable immediately; re-running initialize_model()
    # here would re-embed every source into the same collection and
    # duplicate the existing chunks.
    vectordb.add_documents(chunks)
    vectordb.persist()

    return jsonify({
        "status": "success",
        "message": "Successfully retrained with your data"
    })
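# ============================================================
# ENTRYPOINT
# ============================================================
# Minimal sketch for running the API locally; the host and port are
# illustrative assumptions, and a production deployment would
# typically sit behind a WSGI server such as gunicorn instead of the
# Flask development server.
if __name__ == "__main__":
    app.run(host="0.0.0.0", port=5000)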