| | from flask import Flask, request, jsonify |
| | from flask_cors import CORS |
| | import os |
| | import re |
| | import pandas as pd |
| | from typing import List |
| | from deep_translator import GoogleTranslator |
| |
|
| | from langchain_core.documents import Document |
| | from langchain_community.document_loaders import ( |
| | WebBaseLoader, PyPDFLoader, Docx2txtLoader |
| | ) |
| | from langchain.text_splitter import RecursiveCharacterTextSplitter |
| | from langchain_community.embeddings import HuggingFaceEmbeddings |
| | from langchain_community.vectorstores import Chroma |
| | from langchain.chains import RetrievalQA |
| | from langchain.prompts import PromptTemplate |
| | from langchain.llms import HuggingFacePipeline |
| | from transformers import pipeline |
| |
|
| | from huggingface_hub import snapshot_download |
| |
|
| | |
| | |
| | |
| |
|
# Flask application with CORS enabled so the browser front-end can call the API.
app = Flask(__name__)
CORS(app)


# Hugging Face dataset repo holding the SASTRA knowledge-base files, and the
# local directories used for the downloaded files and the Chroma index.
DATASET_REPO = "bshk57/Sastra_data"
LOCAL_DATASET_DIR = "knowledge_base"
VECTOR_DB_PATH = "sastra_vector_db"

# Multilingual sentence-embedding model for retrieval; small seq2seq LLM for answers.
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
LLM_MODEL = "google/flan-t5-base"

# Admission links that are always appended to every /chat response.
MANDATORY_URLS = [
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
]

# Pages scraped at startup to seed the vector store.
SASTRA_URLS = [
    "https://www.sastra.edu/about-us.html",
    "https://www.sastra.edu/academics/schools.html#school-of-computing",
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
    "https://www.sastra.edu/admissions/hostel-fees.html",
    "https://www.sastra.edu/infrastructure/physical-facilities.html",
    "https://www.sastra.edu/about-us/mission-vision.html",
]

# Make sure the download target exists before snapshot_download runs below.
os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
| |
|
| | |
| | |
| | |
| |
|
# Compiled once at import time; the pattern is reused on every /chat call.
_URL_PATTERN = re.compile(r'https?://[^\s]+')


def extract_urls(text: str):
    """Return every http/https URL found in *text*, in order of appearance."""
    return _URL_PATTERN.findall(text)
| |
|
| |
|
def clean_llm_output(text: str):
    """Normalise raw LLM output for the API response.

    Drops a leading ``Answer:``/``Response:`` label (case-insensitive), maps
    the INSUFFICIENT_DATA sentinel to an empty string, collapses all runs of
    whitespace to single spaces, and caps the result at 700 characters.
    """
    stripped = re.sub(r'^(Answer:|Response:)', '', text, flags=re.I).strip()
    if stripped.lower().startswith("insufficient_data"):
        return ""
    collapsed = re.sub(r'\s+', ' ', stripped)
    return collapsed[:700]
| |
|
| |
|
| | |
| | |
| | |
| |
|
# Download the knowledge-base files from the Hugging Face dataset repo into
# LOCAL_DATASET_DIR at import time, so load_local_documents() can read them.
# NOTE(review): local_dir_use_symlinks is deprecated in recent huggingface_hub
# releases (real files are always written when local_dir is set) — confirm the
# pinned version still accepts it.
print("⬇ Downloading dataset...")
snapshot_download(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    local_dir=LOCAL_DATASET_DIR,
    local_dir_use_symlinks=False,
    ignore_patterns=[".gitattributes"]
)
| |
|
| | |
| | |
| | |
| |
|
def load_local_documents():
    """Load every supported file under LOCAL_DATASET_DIR into LangChain Documents.

    Supported formats: ``.pdf`` (PyPDFLoader), ``.docx`` (Docx2txtLoader) and
    ``.xlsx`` (each spreadsheet row becomes one Document of "col: value"
    pairs joined with " | "). A file that fails to parse is skipped with a
    warning instead of aborting the whole load.

    Returns:
        list: all Documents gathered from the dataset directory.
    """
    docs = []

    # os.walk (rather than os.listdir) so files nested inside sub-directories
    # of the downloaded snapshot are picked up too — snapshot_download
    # preserves the repo's directory structure — and so a directory name can
    # never be passed to a file loader by mistake.
    for root, _dirs, files in os.walk(LOCAL_DATASET_DIR):
        for file in files:
            path = os.path.join(root, file)
            lower = file.lower()
            try:
                if lower.endswith(".pdf"):
                    docs.extend(PyPDFLoader(path).load())

                elif lower.endswith(".docx"):
                    docs.extend(Docx2txtLoader(path).load())

                elif lower.endswith(".xlsx"):
                    df = pd.read_excel(path)
                    for _, row in df.iterrows():
                        # One Document per row; skip empty cells.
                        text = " | ".join(
                            f"{col}: {row[col]}"
                            for col in df.columns
                            if pd.notna(row[col])
                        )
                        docs.append(Document(
                            page_content=text,
                            metadata={"source": file}
                        ))
            except Exception as e:
                print(f"⚠ Error loading {file}: {e}")

    return docs
| |
|
| |
|
| | |
| | |
| | |
| |
|
def initialize_model():
    """(Re)build the full retrieval pipeline.

    Scrapes the SASTRA web pages, merges in the local dataset files, embeds
    everything into a Chroma store persisted at VECTOR_DB_PATH, and wires up
    a flan-t5 RetrievalQA chain.

    Side effects:
        Rebinds the module-level globals ``vectordb`` and ``qa_chain``.
    """
    global vectordb, qa_chain

    documents = []

    # Web sources: one unreachable URL must not sink initialization, but the
    # failure is logged instead of being silently swallowed by a bare except.
    for url in SASTRA_URLS:
        try:
            documents.extend(WebBaseLoader(url).load())
        except Exception as e:
            print(f"⚠ Skipping {url}: {e}")

    # Local dataset files downloaded from the HF dataset repo.
    documents.extend(load_local_documents())

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents(documents)

    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL
    )

    vectordb = Chroma.from_documents(
        chunks,
        embeddings,
        persist_directory=VECTOR_DB_PATH
    )

    # Top-4 nearest chunks are stuffed into the prompt per query.
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})

    # Low temperature + repetition penalty: keep answers grounded and terse.
    generator = pipeline(
        "text2text-generation",
        model=LLM_MODEL,
        tokenizer=LLM_MODEL,
        max_new_tokens=200,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.2
    )

    llm = HuggingFacePipeline(pipeline=generator)

    # INSUFFICIENT_DATA is the sentinel clean_llm_output() maps to "".
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are AskSASTRA, the official SASTRA University assistant.
Answer strictly from the given context.
If information is missing, say INSUFFICIENT_DATA.

Context:
{context}

Question:
{question}

Answer:
"""
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=False
    )

    print("✅ AskSASTRA initialized")
| |
|
| |
|
# Build the vector store and QA chain once at import time, before the app serves.
initialize_model()
| |
|
| | |
| | |
| | |
| |
|
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint confirming the API is up."""
    payload = {"message": "AskSASTRA API running 🚀"}
    return jsonify(payload)
| |
|
| |
|
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user question against the SASTRA knowledge base.

    Expects JSON ``{"query": str, "language": optional ISO code (default "en")}``.
    Non-English queries are translated to English for retrieval and the answer
    is translated back. Responds with ``{"answer": str, "urls": [str]}``.
    """
    data = request.get_json(force=True)
    query = data.get("query", "").strip()
    lang = data.get("language", "en")

    if not query:
        return jsonify({"answer": "Please ask a valid question."})

    # Translate the query into English for retrieval; degrade to the raw text
    # if the translation service fails (logged, not silently swallowed).
    if lang != "en":
        try:
            query_en = GoogleTranslator(source=lang, target="en").translate(query)
        except Exception as e:
            print(f"⚠ Query translation failed: {e}")
            query_en = query
    else:
        query_en = query

    # Run the retrieval-QA chain; any failure falls through to the stock answer.
    try:
        result = qa_chain.invoke({"query": query_en})
        answer_en = clean_llm_output(result.get("result", ""))
    except Exception as e:
        print(f"⚠ QA chain failed: {e}")
        answer_en = ""

    if not answer_en:
        answer_en = (
            "I could not find confident information for this question. "
            "Please visit the official SASTRA website."
        )

    # URLs mentioned in the answer plus the mandatory admission links, deduped
    # while preserving order (list(set(...)) shuffled them nondeterministically).
    all_urls = list(dict.fromkeys(extract_urls(answer_en) + MANDATORY_URLS))

    # Translate the answer back to the requested language when needed.
    if lang != "en":
        try:
            final_answer = GoogleTranslator(source="en", target=lang).translate(answer_en)
        except Exception as e:
            print(f"⚠ Answer translation failed: {e}")
            final_answer = answer_en
    else:
        final_answer = answer_en

    return jsonify({
        "answer": final_answer,
        "urls": all_urls
    })
| |
|
| |
|
@app.route("/retrain", methods=["POST"])
def retrain():
    """Index additional context text supplied by the user.

    Expects JSON ``{"context": str}``. The text is chunked and appended to the
    live Chroma store and persisted; the existing retriever holds a reference
    to the same store, so the new chunks are queryable immediately. A full
    initialize_model() rebuild is deliberately NOT run here: it would
    re-scrape every URL and re-run Chroma.from_documents against the same
    persist directory, duplicating the whole corpus.

    Returns:
        JSON status payload; 400 when no context text is provided.
    """
    data = request.get_json(force=True)
    context = data.get("context", "").strip()

    if not context:
        return jsonify({"error": "Context text required"}), 400

    doc = Document(
        page_content=context,
        metadata={"source": "user_retrain"}
    )

    # Same chunking parameters as the startup index, so retrieval behaves
    # uniformly across original and user-supplied content.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents([doc])

    vectordb.add_documents(chunks)
    vectordb.persist()

    return jsonify({
        "status": "success",
        "message": "Successfully retrained with your data"
    })