# ustp / app.py — USTP Student Handbook chatbot (Streamlit + LangChain RAG)
# Provenance: uploaded by anneee266333, commit 244fa02 (verified) — the raw
# Hugging Face Spaces page header was converted to this comment so the file
# remains valid Python.
# Standard library
import hashlib
import io
import os
import pathlib
import tempfile
from typing import List

# Third-party
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from langchain_community.llms import HuggingFacePipeline
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pypdf import PdfReader
# CACHE SETUP (makes Spaces happy)
def _ensure_dir(p: str) -> str:
pathlib.Path(p).mkdir(parents=True, exist_ok=True)
return p
# Pick the first cache directory we can actually create; Spaces containers
# differ in which paths are writable, so probe a list of candidates.
# NOTE(review): these env vars are assigned *after* transformers/langchain are
# imported above; libraries that read them at import time may already have
# captured their defaults — consider moving this block above those imports.
HF_CACHE_DIR = None
for candidate in (
    os.getenv("HF_CACHE_DIR"),                    # explicit override wins
    "/data/.cache/huggingface",                   # persistent storage on paid Spaces
    os.path.expanduser("~/.cache/huggingface"),
    "/home/user/.cache/huggingface",
    "/tmp/hf_cache",
):
    if not candidate:
        continue
    try:
        HF_CACHE_DIR = _ensure_dir(candidate)
        break
    except Exception:
        continue
if HF_CACHE_DIR is None:
    # Fix: the original left HF_CACHE_DIR unbound (NameError below) if every
    # candidate failed; a fresh temp dir always exists and is writable.
    HF_CACHE_DIR = tempfile.mkdtemp(prefix="hf_cache_")
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = _ensure_dir(os.path.join(HF_CACHE_DIR, "transformers"))
os.environ["SENTENCE_TRANSFORMERS_HOME"] = _ensure_dir(os.path.join(HF_CACHE_DIR, "sentence_transformers"))
os.environ["TORCH_HOME"] = _ensure_dir(os.path.join(HF_CACHE_DIR, "torch"))
# CONFIG
# Streamlit page chrome — must run before any other st.* call in the script.
st.set_page_config(page_title="USTP Handbook Chatbot", page_icon="📚", layout="wide")
# Small instruct model that fits the free CPU tier; overridable in the sidebar.
DEFAULT_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_MAX_NEW_TOKENS = 256   # generation budget per answer
DEFAULT_TEMPERATURE = 0.2      # near-deterministic answers by default
HANDBOOK_PATH = "USTP Student Handbook 2023 Edition.pdf" # your uploaded handbook filename
# System message for the RAG prompt: restricts the model to retrieved context.
SYSTEM_PROMPT = (
    "You are a friendly and factual assistant answering student queries about the USTP Student Handbook. "
    "Use ONLY the provided context to answer. If the answer isn’t in the context, say you don’t know. "
    "Be concise and, when possible, mention the chunk number."
)
# UTILITIES
def read_pdf_bytes_to_text(file_like: io.BytesIO) -> str:
    """Extract and newline-join the text of every page in a PDF stream."""
    file_like.seek(0)  # rewind in case the caller already consumed the stream
    reader = PdfReader(file_like)
    # extract_text() may return None for image-only pages; coerce to "".
    return "\n".join(page.extract_text() or "" for page in reader.pages)
def compute_texts_hash(texts: List[str]) -> str:
    """Return a SHA-256 hex digest identifying this exact sequence of texts."""
    digest = hashlib.sha256()
    digest.update("\n".join(texts).encode("utf-8"))
    return digest.hexdigest()
def format_docs(docs):
    """Render retrieved documents as 1-based numbered chunks, blank-line separated."""
    numbered = []
    for index, doc in enumerate(docs, start=1):
        numbered.append(f"[{index}] {doc.page_content}")
    return "\n\n".join(numbered)
# CACHED FUNCTIONS
@st.cache_resource(show_spinner=True)
def get_embeddings():
    """Build (once per process) the MiniLM sentence-transformer embedding model."""
    # Local import so the newer langchain_huggingface wrapper shadows the
    # deprecated langchain_community one imported at module scope.
    from langchain_huggingface import HuggingFaceEmbeddings
    cache_folder = os.environ.get("SENTENCE_TRANSFORMERS_HOME")
    return HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        cache_folder=cache_folder,
        # allow a network download on first run; cached afterwards
        model_kwargs={"local_files_only": False},
    )
@st.cache_resource(show_spinner=True)
def load_llm(model_id: str = DEFAULT_MODEL_ID,
             temperature: float = DEFAULT_TEMPERATURE,
             max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
    """Load tokenizer + causal LM and wrap them as a LangChain pipeline LLM.

    Cached by st.cache_resource on (model_id, temperature, max_new_tokens),
    so identical settings reuse the already-loaded model.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    lm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU-friendly dtype
        low_cpu_mem_usage=True,
    )
    sampling = temperature > 0.0  # temperature 0 → greedy decoding
    text_gen = pipeline(
        "text-generation",
        model=lm,
        tokenizer=tok,
        device=-1,  # force CPU
        do_sample=sampling,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        repetition_penalty=1.1,
        pad_token_id=tok.eos_token_id,
        return_full_text=False,  # emit only the completion, not the prompt
    )
    return HuggingFacePipeline(pipeline=text_gen)
def build_faiss_index(texts: List[str], chunk_size: int = 800, chunk_overlap: int = 120):
    """Split raw texts into overlapping chunks and embed them into a FAISS index."""
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ).create_documents(texts)
    return FAISS.from_documents(chunks, embedding=get_embeddings())
def make_rag_chain(retriever, llm):
    """Compose retriever → prompt → LLM → string parser into one LCEL runnable.

    The input question is passed through unchanged while the retriever output
    is formatted into the {context} slot of the prompt.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", SYSTEM_PROMPT),
        ("human", "Context:\n{context}\n\nQuestion: {question}")
    ])
    inputs = {
        "context": retriever | RunnableLambda(format_docs),
        "question": RunnablePassthrough(),
    }
    return inputs | prompt | llm | StrOutputParser()
# LOAD KNOWLEDGE BASE AUTOMATICALLY
@st.cache_resource(show_spinner=True)
def load_handbook_kb():
    """Read the bundled handbook PDF and build its FAISS knowledge base.

    Returns:
        (vectorstore, sha256_hex) on success, or (None, None) on any failure
        (missing file, unreadable PDF, empty extraction). Errors are surfaced
        in the UI via st.error so the app keeps running.
    """
    try:
        with open(HANDBOOK_PATH, "rb") as f:
            # Pass the open file straight through: PdfReader only needs a
            # seekable binary stream, so there is no reason to copy the whole
            # PDF into an in-memory BytesIO first (the original doubled RAM).
            text = read_pdf_bytes_to_text(f)
        if not text.strip():
            # A scanned/image-only PDF extracts to nothing; indexing it would
            # silently produce a useless knowledge base.
            raise ValueError("no extractable text found in the handbook PDF")
        kb_hash = compute_texts_hash([text])
        with st.spinner("Embedding & indexing USTP Handbook…"):
            vs = build_faiss_index([text])
        return vs, kb_hash
    except Exception as e:
        st.error(f"⚠️ Failed to load handbook: {e}")
        return None, None
# UI
# Page header
st.title("🎓 USTP Student Handbook Chatbot")
st.write(
    "I'm your AI assistant trained on the **USTP Student Handbook**. "
    "Ask me anything about school policies, rules, or student life!"
)
# Sidebar controls
with st.sidebar:
    st.header("⚙️ Model & Settings")
    # NOTE(review): the label recommends Mistral-7B but the default is the
    # much smaller Qwen 1.5B — a 7B model will likely not fit the free CPU tier.
    model_id = st.text_input(
        "Model ID (Recommend: mistralai/Mistral-7B-Instruct-v0.2)",
        value=DEFAULT_MODEL_ID,
    )
    # Generation settings; consumed by load_llm() further down the script.
    temperature = st.slider("Temperature", 0.0, 1.0, DEFAULT_TEMPERATURE, 0.05)
    max_new_tokens = st.slider("Max tokens", 32, 1024, DEFAULT_MAX_NEW_TOKENS, 32)
    # Retrieval breadth: how many handbook chunks are fed into the prompt.
    k = st.slider("Number of chunks (k)", 1, 10, 4, 1)
# Initialize knowledge base (load_handbook_kb is memoized by st.cache_resource,
# so this is only expensive on the very first run).
if "vectorstore" not in st.session_state:
    vs, kb_hash = load_handbook_kb()
    # Explicit None check — vector stores should not be tested for truthiness.
    if vs is not None:
        st.session_state["vectorstore"] = vs
        st.session_state["kb_hash"] = kb_hash
# Initialize/refresh the LLM on every rerun. The original gated this on
# `"llm" not in st.session_state`, which meant sidebar changes to model id,
# temperature, or max tokens never took effect after the first load.
# st.cache_resource keys on (model_id, temperature, max_new_tokens), so
# unchanged settings reuse the cached pipeline at negligible cost.
with st.spinner("Loading model…"):
    st.session_state["llm"] = load_llm(model_id, temperature, max_new_tokens)
from streamlit_chat import message  # NOTE(review): mid-file import kept to preserve original load order
# --- Chat Interface ---
st.markdown("---")
# --- Initialize Chat History ---
if "chat_history" not in st.session_state:
    st.session_state.chat_history = []
# --- Display Existing Messages ---
for idx, entry in enumerate(st.session_state.chat_history):
    if entry["role"] != "user":
        # assistant’s main message
        message(entry["content"], is_user=False, key=f"bot_{idx}", avatar_style="big-smile")
        # source chunk collapsible below
        source = entry.get("source")
        if source:
            with st.expander("📘 Source Chunk"):
                st.markdown(source)
    else:
        message(entry["content"], is_user=True, key=f"user_{idx}", avatar_style="thumbs")
# --- Input Field (modern Streamlit style) ---
user_question = st.chat_input("Ask about USTP policies, rules, or student life...")
if user_question:
    # Record the user's turn first so it renders even if answering fails.
    st.session_state.chat_history.append({"role": "user", "content": user_question})
    if "vectorstore" not in st.session_state:
        st.warning("Knowledge base not loaded yet.")
    elif "llm" not in st.session_state:
        # Guard against a failed model load instead of crashing with KeyError.
        st.warning("Model not loaded yet.")
    else:
        vs = st.session_state["vectorstore"]
        llm = st.session_state["llm"]
        retriever = vs.as_retriever(search_type="similarity", search_kwargs={"k": k})
        with st.spinner("USTP Student Handbook Chatbot is reviewing your request..."):
            # Retrieve ONCE and reuse the docs for both the prompt context and
            # the "source chunk" display — the original ran the same similarity
            # search twice per question.
            docs = retriever.invoke(user_question)
            chain = make_rag_chain(RunnableLambda(lambda _: docs), llm)
            answer = chain.invoke(user_question)
            # Strip a leading "Answer:" prefix the model sometimes emits.
            answer = answer.strip()
            if answer.lower().startswith("answer:"):
                answer = answer[len("answer:"):].strip()
            source_chunk = docs[0].page_content[:800] if docs else "No source found."
            # store assistant’s message
            st.session_state.chat_history.append({
                "role": "assistant",
                "content": answer,
                "source": source_chunk
            })
        st.rerun()