# FlaskAsk / app.py
# NOTE(review): removed Hugging Face page chrome that was pasted into the
# file ("bshk57's picture", "Update app.py", "ffc9415 verified") — that was
# scraped UI text, not source code, and made the module unparseable.
from flask import Flask, request, jsonify
from flask_cors import CORS
import os
import re
import pandas as pd
from typing import List
from deep_translator import GoogleTranslator
from langchain_core.documents import Document
from langchain_community.document_loaders import (
WebBaseLoader, PyPDFLoader, Docx2txtLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from huggingface_hub import snapshot_download
# ============================================================
# APP INIT
# ============================================================
app = Flask(__name__)
# Enable CORS so a separately hosted browser front-end can call this API.
CORS(app)
# ============================================================
# CONSTANTS
# ============================================================
# Hugging Face dataset repo holding the curated knowledge-base files.
DATASET_REPO = "bshk57/Sastra_data"
# Local directory the dataset snapshot is downloaded into (created below).
LOCAL_DATASET_DIR = "knowledge_base"
# On-disk persistence directory for the Chroma vector store.
VECTOR_DB_PATH = "sastra_vector_db"
# Multilingual sentence embedder — queries may arrive in non-English languages.
EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Seq2seq model used for answer generation.
LLM_MODEL = "google/flan-t5-large"
# Admission links always appended to every /chat response.
MANDATORY_URLS = [
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
]
# Pages scraped at startup to seed the vector index.
SASTRA_URLS = [
    "https://www.sastra.edu/about-us.html",
    "https://www.sastra.edu/academics/schools.html#school-of-computing",
    "https://www.sastra.edu/admissions/ug-pg.html",
    "https://www.sastra.edu/admissions/eligibility-criteria.html",
    "https://www.sastra.edu/admissions/fee-structure.html",
    "https://www.sastra.edu/admissions/hostel-fees.html",
    "https://www.sastra.edu/infrastructure/physical-facilities.html",
    "https://www.sastra.edu/about-us/mission-vision.html",
]
os.makedirs(LOCAL_DATASET_DIR, exist_ok=True)
def extract_urls(text: str) -> List[str]:
    """Return all http(s) URLs found in *text*.

    Trailing sentence punctuation (".", ",", ")", quotes, etc.) is stripped
    from each match — the old pattern ``[^\\s]+`` kept it attached, which
    produced broken links like ``.../fee-structure.html.``.
    """
    # \S+ grabs the run of non-space characters after the scheme; rstrip
    # removes punctuation that belongs to the sentence, not the URL.
    return [u.rstrip('.,;:!?)]}\'"') for u in re.findall(r'https?://\S+', text)]
def clean_llm_output(text: str):
    """Tidy a raw LLM completion for display.

    Drops a leading "Answer:"/"Response:" label (any case), returns "" when
    the model emitted the INSUFFICIENT_DATA sentinel, collapses whitespace
    runs to single spaces, and truncates the result to 700 characters.
    """
    lowered = text.lower()
    for label in ("answer:", "response:"):
        if lowered.startswith(label):
            text = text[len(label):]
            break
    text = text.strip()
    if text.lower().startswith("insufficient_data"):
        return ""
    return " ".join(text.split())[:700]
# ============================================================
# LOAD DATASET FROM HUGGINGFACE
# ============================================================
# Pull the knowledge-base files (PDF/DOCX/XLSX) from the HF dataset repo
# into LOCAL_DATASET_DIR. Runs at import time, so the app blocks until
# the download finishes.
print("โฌ‡ Downloading dataset...")
snapshot_download(
    repo_id=DATASET_REPO,
    repo_type="dataset",
    local_dir=LOCAL_DATASET_DIR,
    # NOTE(review): local_dir_use_symlinks is deprecated/ignored in recent
    # huggingface_hub releases — confirm the pinned version still accepts it.
    local_dir_use_symlinks=False,
    ignore_patterns=[".gitattributes"]
)
# ============================================================
# LOAD LOCAL DOCUMENTS
# ============================================================
def load_local_documents():
    """Load every supported file in LOCAL_DATASET_DIR into LangChain Documents.

    Supported formats: .pdf (PyPDFLoader), .docx (Docx2txtLoader) and .xlsx
    (each spreadsheet row becomes one pipe-joined Document). A file that
    fails to parse is skipped with a warning so one bad file cannot break
    startup.

    Returns:
        list: the loaded Document objects (empty if nothing loadable).
    """
    docs = []
    # sorted() gives a deterministic load order across runs/filesystems.
    for file in sorted(os.listdir(LOCAL_DATASET_DIR)):
        path = os.path.join(LOCAL_DATASET_DIR, file)
        # snapshot_download can leave sub-directories (e.g. caches) in the
        # target dir; only regular files can be parsed by the loaders.
        if not os.path.isfile(path):
            continue
        name = file.lower()
        try:
            if name.endswith(".pdf"):
                docs.extend(PyPDFLoader(path).load())
            elif name.endswith(".docx"):
                docs.extend(Docx2txtLoader(path).load())
            elif name.endswith(".xlsx"):
                df = pd.read_excel(path)
                for _, row in df.iterrows():
                    # One Document per row: "col: value | col: value | ..."
                    text = " | ".join(
                        f"{col}: {row[col]}"
                        for col in df.columns
                        if pd.notna(row[col])
                    )
                    docs.append(Document(
                        page_content=text,
                        metadata={"source": file}
                    ))
        except Exception as e:
            # Best-effort: log and continue with the remaining files.
            print(f"โš  Error loading {file}: {e}")
    return docs
# ============================================================
# INITIALIZE RAG MODEL
# ============================================================
def initialize_model():
    """Build the global RAG pipeline.

    Scrapes SASTRA_URLS, loads the downloaded dataset files, chunks and
    embeds everything into a persisted Chroma index, and wires a
    RetrievalQA chain around a FLAN-T5 generation pipeline.

    Side effects:
        Sets the module globals ``vectordb`` and ``qa_chain``.
    """
    global vectordb, qa_chain
    documents = []
    # Website data — best effort: a dead page must not abort startup, but
    # the failure should be visible in the logs (the old bare `except: pass`
    # also swallowed KeyboardInterrupt/SystemExit and hid which URL failed).
    for url in SASTRA_URLS:
        try:
            documents.extend(WebBaseLoader(url).load())
        except Exception as e:
            print(f"Skipping {url}: {e}")
    # Dataset + uploaded docs
    documents.extend(load_local_documents())
    # ~600-char chunks with 50-char overlap keep retrieved context small
    # enough for the flan-t5 context window.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents(documents)
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL
    )
    vectordb = Chroma.from_documents(
        chunks,
        embeddings,
        persist_directory=VECTOR_DB_PATH
    )
    # Retrieve the 4 most similar chunks per query.
    retriever = vectordb.as_retriever(search_kwargs={"k": 4})
    generator = pipeline(
        "text2text-generation",
        model=LLM_MODEL,
        tokenizer=LLM_MODEL,
        max_new_tokens=200,
        # NOTE(review): temperature/top_p only take effect when sampling is
        # enabled; transformers may warn without do_sample=True — confirm
        # whether greedy decoding is intended here.
        temperature=0.1,
        top_p=0.9,
        repetition_penalty=1.2
    )
    llm = HuggingFacePipeline(pipeline=generator)
    prompt = PromptTemplate(
        input_variables=["context", "question"],
        template="""
You are AskSASTRA, the official SASTRA University assistant.
Answer strictly from the given context.
If information is missing, say INSUFFICIENT_DATA.
Context:
{context}
Question:
{question}
Answer:
"""
    )
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        retriever=retriever,
        chain_type="stuff",
        chain_type_kwargs={"prompt": prompt},
        return_source_documents=False
    )
    print("โœ… AskSASTRA initialized")
# Build the index and QA chain once at import time so the first request
# does not pay the (potentially multi-minute) startup cost.
initialize_model()
# ============================================================
# ROUTES
# ============================================================
@app.route("/", methods=["GET"])
def home():
    """Health-check endpoint: confirms the API process is up."""
    payload = {"message": "AskSASTRA API running ๐Ÿš€"}
    return jsonify(payload)
@app.route("/chat", methods=["POST"])
def chat():
    """Answer a user question via the RAG chain.

    Expects JSON ``{"query": str, "language": ISO-639-1 code (default "en")}``.
    Non-English queries are translated to English for retrieval, and the
    answer is translated back before returning.

    Returns:
        JSON ``{"answer": str, "urls": [str]}``.
    """
    # `or {}` guards against a literal JSON null body, which get_json
    # parses to None and would otherwise crash on .get().
    data = request.get_json(force=True) or {}
    query = data.get("query", "").strip()
    lang = data.get("language", "en")
    if not query:
        return jsonify({"answer": "Please ask a valid question."})
    # Translate input -> English. Best effort: a translator outage should
    # degrade answer quality, not fail the request — but it must be logged
    # (the old bare `except:` hid every failure, including SystemExit).
    query_en = query
    if lang != "en":
        try:
            query_en = GoogleTranslator(source=lang, target="en").translate(query)
        except Exception as e:
            print(f"Input translation failed ({lang} -> en): {e}")
    # RAG inference
    answer_en = ""
    try:
        result = qa_chain.invoke({"query": query_en})
        answer_en = clean_llm_output(result.get("result", ""))
    except Exception as e:
        print(f"QA chain failed: {e}")
    if not answer_en:
        answer_en = (
            "I could not find confident information for this question. "
            "Please visit the official SASTRA website."
        )
    # Always include the mandatory admission URLs alongside any URLs the
    # model produced (set() de-duplicates; order is not preserved).
    extracted_urls = extract_urls(answer_en)
    all_urls = list(set(extracted_urls + MANDATORY_URLS))
    # Translate output back to the user's language (same best-effort rule).
    final_answer = answer_en
    if lang != "en":
        try:
            final_answer = GoogleTranslator(source="en", target=lang).translate(answer_en)
        except Exception as e:
            print(f"Output translation failed (en -> {lang}): {e}")
    return jsonify({
        "answer": final_answer,
        "urls": all_urls
    })
@app.route("/retrain", methods=["POST"])
def retrain():
    """Append user-supplied context to the live vector index.

    Expects JSON ``{"context": str}``; returns 400 when it is missing.
    The text is chunked with the same parameters as the base corpus and
    added to the existing Chroma collection. The QA chain's retriever
    shares that collection object, so the new chunks are searchable
    immediately — no rebuild is needed. (The previous implementation also
    called initialize_model() afterwards, which re-ran Chroma.from_documents
    into the same persist directory and re-inserted every scraped/dataset
    document, duplicating the corpus on every retrain call.)
    """
    data = request.get_json(force=True) or {}
    context = data.get("context", "").strip()
    if not context:
        return jsonify({"error": "Context text required"}), 400
    doc = Document(
        page_content=context,
        metadata={"source": "user_retrain"}
    )
    # Same chunking parameters as the base index so retrieval scoring
    # stays consistent across old and new content.
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=600,
        chunk_overlap=50
    )
    chunks = splitter.split_documents([doc])
    vectordb.add_documents(chunks)
    vectordb.persist()
    return jsonify({
        "status": "success",
        "message": "Successfully retrained with your data"
    })