|
|
from flask import Flask, request, jsonify |
|
|
from flask_cors import CORS |
|
|
import os |
|
|
import re |
|
|
import pandas as pd |
|
|
from typing import List |
|
|
from deep_translator import GoogleTranslator |
|
|
|
|
|
from langchain_core.documents import Document |
|
|
from langchain_community.document_loaders import ( |
|
|
WebBaseLoader, PyPDFLoader, Docx2txtLoader |
|
|
) |
|
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
|
from langchain_community.vectorstores import Chroma |
|
|
from langchain.chains import RetrievalQA |
|
|
from langchain.prompts import PromptTemplate |
|
|
from langchain.llms import HuggingFacePipeline |
|
|
from transformers import pipeline |
|
|
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Flask application with CORS enabled so the browser front-end (served
# from another origin) can call the API directly.
app = Flask(__name__)

CORS(app)

# Hugging Face dataset repo holding the local knowledge-base files and the
# directory they are downloaded into.
DATASET_REPO = "bshk57/Sastra_data"

LOCAL_DATASET_DIR = "knowledge_base"

# On-disk persistence directory for the Chroma vector store.
VECTOR_DB_PATH = "sastra_vector_db"

# Multilingual sentence-embedding model (supports the translated queries)
# and the seq2seq LLM used for answer generation.
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"

LLM_MODEL = "google/flan-t5-large"

# URLs always appended to every /chat response, regardless of the answer.
MANDATORY_URLS = [

    "https://www.sastra.edu/admissions/ug-pg.html",

    "https://www.sastra.edu/admissions/eligibility-criteria.html",

    "https://www.sastra.edu/admissions/fee-structure.html",

]

# Pages scraped into the knowledge base at startup by initialize_model().
SASTRA_URLS = [

    "https://www.sastra.edu/about-us.html",

    "https://www.sastra.edu/academics/schools.html#school-of-computing",

    "https://www.sastra.edu/admissions/ug-pg.html",

    "https://www.sastra.edu/admissions/eligibility-criteria.html",

    "https://www.sastra.edu/admissions/fee-structure.html",

    "https://www.sastra.edu/admissions/hostel-fees.html",

    "https://www.sastra.edu/infrastructure/physical-facilities.html",

    "https://www.sastra.edu/about-us/mission-vision.html",

]

# Ensure the download target exists before snapshot_download() runs.
os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_urls(text: str) -> List[str]:
    """Return every http(s) URL found in *text*, in order of appearance.

    The greedy ``[^\\s]+`` match would otherwise capture trailing sentence
    punctuation (e.g. "see https://x.com." -> "https://x.com."), producing
    broken links, so common trailing punctuation is stripped from each hit.
    """
    return [
        url.rstrip('.,;:!?)]}\'"')
        for url in re.findall(r'https?://[^\s]+', text)
    ]
|
|
|
|
|
|
|
|
def clean_llm_output(text: str):
    """Normalise raw LLM output for display.

    Drops a leading "Answer:"/"Response:" label (case-insensitive),
    returns "" when the model emitted the INSUFFICIENT_DATA sentinel,
    collapses whitespace runs to single spaces, and caps the result at
    700 characters.
    """
    stripped = re.sub(r'^(Answer:|Response:)', '', text, flags=re.I).strip()

    # The prompt instructs the model to reply INSUFFICIENT_DATA when the
    # context lacks an answer; an empty string triggers the fallback reply.
    if stripped.lower().startswith("insufficient_data"):
        return ""

    collapsed = re.sub(r'\s+', ' ', stripped)
    return collapsed[:700]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Pull the knowledge-base files from the Hugging Face dataset repo into
# LOCAL_DATASET_DIR at import time, so load_local_documents() finds them.
# NOTE(review): `local_dir_use_symlinks` is deprecated in recent
# huggingface_hub releases — confirm the installed version still accepts it.
print("โฌ Downloading dataset...")

snapshot_download(

    repo_id=DATASET_REPO,

    repo_type="dataset",

    local_dir=LOCAL_DATASET_DIR,

    local_dir_use_symlinks=False,

    # Skip repo housekeeping files that are not knowledge-base content.
    ignore_patterns=[".gitattributes"]

)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_local_documents():
    """Load every supported file in LOCAL_DATASET_DIR into Documents.

    Supported formats: .pdf (PyPDFLoader), .docx (Docx2txtLoader) and
    .xlsx, where each spreadsheet row becomes one Document whose text is
    a " | "-joined list of "column: value" pairs (empty cells skipped).
    Files that fail to load are reported and skipped; any other file
    type is silently ignored.
    """
    collected = []

    for name in os.listdir(LOCAL_DATASET_DIR):
        full_path = os.path.join(LOCAL_DATASET_DIR, name)
        lowered = name.lower()
        try:
            if lowered.endswith(".pdf"):
                collected.extend(PyPDFLoader(full_path).load())

            elif lowered.endswith(".docx"):
                collected.extend(Docx2txtLoader(full_path).load())

            elif lowered.endswith(".xlsx"):
                frame = pd.read_excel(full_path)
                # One Document per row keeps each record retrievable on
                # its own instead of one giant sheet-sized chunk.
                for _, record in frame.iterrows():
                    row_text = " | ".join(
                        f"{column}: {record[column]}"
                        for column in frame.columns
                        if pd.notna(record[column])
                    )
                    collected.append(Document(
                        page_content=row_text,
                        metadata={"source": name}
                    ))
        except Exception as e:
            print(f"โ Error loading {name}: {e}")

    return collected
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def initialize_model():
    """Build (or rebuild) the global RAG pipeline.

    Scrapes the SASTRA_URLS pages, merges in the local knowledge-base
    files, splits everything into chunks, embeds them into a persistent
    Chroma store, and wires a flan-t5 text2text pipeline into a
    RetrievalQA chain.

    Side effects: sets the module-level globals ``vectordb`` and
    ``qa_chain``; downloads embedding/LLM weights on first run.
    """
    global vectordb, qa_chain

    documents = []

    # Best-effort web scrape: one unreachable page must not abort startup,
    # but log the failure instead of silently swallowing it (the previous
    # bare `except: pass` hid every error, including typos in the URL list).
    for url in SASTRA_URLS:
        try:
            documents.extend(WebBaseLoader(url).load())
        except Exception as e:
            print(f"Warning: failed to load {url}: {e}")

    documents.extend(load_local_documents())

    # 600-char chunks with 50-char overlap; the overlap keeps sentences
    # that straddle a boundary retrievable from either side.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents(documents)

    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL
    )

    vectordb = Chroma.from_documents(
        chunks,
        embeddings,
        persist_directory=VECTOR_DB_PATH
    )

    # Retrieve the 4 most similar chunks per query.
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})

    # Low temperature + repetition penalty keep answers close to the
    # retrieved context rather than free-form generation.
    generator = pipeline(
        "text2text-generation",
        model=LLM_MODEL,
        tokenizer=LLM_MODEL,
        max_new_tokens=200,
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.2
    )

    llm = HuggingFacePipeline(pipeline=generator)

    # The INSUFFICIENT_DATA sentinel is detected later by clean_llm_output().
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are AskSASTRA, the official SASTRA University assistant.
Answer strictly from the given context.
If information is missing, say INSUFFICIENT_DATA.

Context:
{context}

Question:
{question}

Answer:
"""
    )

    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=False
    )

    # Original string literal here was corrupted (split across lines /
    # mojibake emoji), which is a syntax error; replaced with plain ASCII.
    print("AskSASTRA initialized")
|
|
|
|
|
|
|
|
# Build the vector store and QA chain once at import time so the first
# /chat request does not pay the scraping and model-loading cost.
initialize_model()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint confirming the API is up."""
    payload = {"message": "AskSASTRA API running ๐"}
    return jsonify(payload)
|
|
|
|
|
|
|
|
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user question through the RAG chain.

    Expects JSON ``{"query": str, "language": ISO-639 code (default "en")}``.
    Non-English queries are translated to English before retrieval and the
    answer is translated back.  Returns ``{"answer": str, "urls": [...]}``
    where ``urls`` is any link mentioned in the answer plus MANDATORY_URLS.
    """
    data = request.get_json(force=True)
    query = data.get("query", "").strip()
    lang = data.get("language", "en")

    if not query:
        return jsonify({"answer": "Please ask a valid question."})

    # Translate the question to English; on translator failure fall back to
    # the raw text (previously a bare `except:` hid the error entirely).
    query_en = query
    if lang != "en":
        try:
            query_en = GoogleTranslator(source=lang, target="en").translate(query)
        except Exception as e:
            print(f"Warning: query translation failed: {e}")

    # Run retrieval + generation; an empty answer triggers the fallback.
    try:
        result = qa_chain.invoke({"query": query_en})
        answer_en = clean_llm_output(result.get("result", ""))
    except Exception as e:
        print(f"Warning: QA chain failed: {e}")
        answer_en = ""

    if not answer_en:
        answer_en = (
            "I could not find confident information for this question. "
            "Please visit the official SASTRA website."
        )

    # De-duplicate while preserving order: list(set(...)) shuffled the URL
    # order nondeterministically between identical requests.
    extracted_urls = extract_urls(answer_en)
    all_urls = list(dict.fromkeys(extracted_urls + MANDATORY_URLS))

    # Translate the answer back to the caller's language, best-effort.
    final_answer = answer_en
    if lang != "en":
        try:
            final_answer = GoogleTranslator(source="en", target=lang).translate(answer_en)
        except Exception as e:
            print(f"Warning: answer translation failed: {e}")

    return jsonify({
        "answer": final_answer,
        "urls": all_urls
    })
|
|
|
|
|
|
|
|
@app.route("/retrain", methods=["POST"])
def retrain():
    """Add user-supplied context text to the vector store.

    Expects JSON ``{"context": str}``.  The text is chunked with the same
    parameters as initialize_model(), embedded into the existing Chroma
    store, persisted, and the whole pipeline is rebuilt.

    NOTE(review): initialize_model() re-scrapes every URL and re-adds all
    documents to the same persist directory, which looks like it could
    duplicate existing entries — confirm whether the full rebuild here is
    intentional.
    """
    payload = request.get_json(force=True)
    user_context = payload.get("context", "").strip()

    if not user_context:
        return jsonify({"error": "Context text required"}), 400

    new_doc = Document(
        page_content=user_context,
        metadata={"source": "user_retrain"}
    )

    # Same chunking parameters as initialize_model() for consistency.
    chunker = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    new_chunks = chunker.split_documents([new_doc])

    vectordb.add_documents(new_chunks)
    vectordb.persist()
    initialize_model()

    return jsonify({
        "status": "success",
        "message": "Successfully retrained with your data"
    })