import streamlit as st
import os
import re
import json
import demjson3
import requests
import faiss
import numpy as np
import multiprocessing
import time
from huggingface_hub import hf_hub_download, login
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
# --- Configuration & Constants ---
# set_page_config must be Streamlit's first command, so call it before any
# st.error/st.stop below can run
st.set_page_config(page_title="Boston School Choice Chatbot", page_icon="🏫", layout="wide")
# Get API key from Streamlit secrets
try:
GEOAPIFY_KEY = st.secrets["GEOAPIFY_KEY"]
except KeyError:
st.error("Geoapify API key not found. Please add it to Streamlit secrets.")
st.stop()
# HF_TOKEN is optional but recommended for downloads
HF_TOKEN = st.secrets.get("HF_TOKEN")  # .get() returns None instead of raising KeyError
if HF_TOKEN:
    login(token=HF_TOKEN)
# Model and RAG configuration
MODEL_REPO_ID = "bartowski/gemma-2-2b-it-GGUF"
MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
DOCS_PATH = "docs"
FAISS_INDEX_PATH = "bps_faiss.index"
# LLM parameters
N_CTX = 2048
MAX_TOKENS_RESPONSE = 350
TEMPERATURE = 0.5
N_THREADS = max(1, multiprocessing.cpu_count() - 1)  # leave one core free, but never drop below one thread
# RAG parameters
TOP_K_DOCS = 3
# Import prompts from prompts.py
try:
from prompts import system_prompt, json_prompt, initial_school_search_prompt
except ImportError:
st.error("Could not import prompts from prompts.py. Make sure the file exists.")
system_prompt = """
You are a professional assistant that answers questions about enrollment in Boston Public Schools.
Be friendly and helpful. Families will ask questions and provide information, such as the child's residence, grade, and school preference.
It is essential that you do not provide inaccurate information. If you do not know something, say so rather than guessing.
"""
json_prompt = """
You must provide a response to the user and an updated JSON object with any new facts you learn. Only update fields that are explicitly mentioned.
"""
initial_school_search_prompt = """
Keep the conversation going and ask questions one at a time until you have all the information you need. Once you have a residence and grade, you have the ability to look up nearby schools.
"""
# --- Helper Functions ---
def clean_reply_text(reply: str) -> str:
"""Removes potential JSON blocks and cleans up common LLM artifacts."""
reply = re.sub(r"```[jJ][sS][oO][nN]?\s*(\{.*?\})\s*```", "", reply, flags=re.DOTALL)
reply = re.sub(r"\s*\{.*\}\s*$", "", reply, flags=re.DOTALL)
reply = re.sub(r"`", "", reply)
reply = re.sub(r"(?i)\bjson\b", "", reply)
reply = re.sub(r"[\[\]]", "", reply)
reply = re.sub(r"\n{2,}", "\n", reply)
return reply.strip()
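# Illustrative example (hypothetical input) of what clean_reply_text strips:
#   clean_reply_text('Sure! ```json {"grade": "2"} ```')  ->  'Sure!'
# Fenced JSON blocks, stray backticks, square brackets, and the bare word
# "json" are all removed before the reply is shown to the user.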
def extract_reply_and_json(text: str) -> tuple[str, dict]:
"""
Extracts the natural language reply and the last valid JSON object from the LLM response.
Uses demjson3 for potentially more lenient parsing.
"""
json_part = {}
reply_part = text
last_brace_open = text.rfind('{')
if last_brace_open != -1:
brace_level = 0
last_brace_close = -1
potential_json_str = text[last_brace_open:]
for i, char in enumerate(potential_json_str):
if char == '{':
brace_level += 1
elif char == '}':
brace_level -= 1
if brace_level == 0:
last_brace_close = last_brace_open + i
break
if last_brace_close != -1:
json_str = text[last_brace_open : last_brace_close + 1]
try:
parsed = demjson3.decode(json_str)
if isinstance(parsed, dict):
json_part = parsed
reply_part = text[:last_brace_open].strip()
except demjson3.JSONDecodeError:
pass
    # reply_part already excludes any parsed JSON block, so one cleanup pass suffices
    cleaned_reply = clean_reply_text(reply_part)
    return cleaned_reply, json_part
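# Illustrative example (hypothetical LLM output):
#   extract_reply_and_json('Got it, grade 2. {"grade": "2"}')
#   ->  ('Got it, grade 2.', {'grade': '2'})
# The brace walk above isolates the last balanced {...} block, so braces that
# appear earlier in the prose do not confuse the parser.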
def geocode_address(address: str) -> tuple[float | None, float | None]:
"""Turn a free‑form address into (lat, lon) using Geoapify."""
if not GEOAPIFY_KEY:
return None, None
try:
resp = requests.get(
"https://api.geoapify.com/v1/geocode/search",
params={"text": address, "limit": 1, "apiKey": GEOAPIFY_KEY},
timeout=10
)
resp.raise_for_status()
features = resp.json().get("features", [])
if not features:
return None, None
lon, lat = features[0]["geometry"]["coordinates"]
return lat, lon
except requests.exceptions.RequestException as e:
st.error(f"Geocoding API request failed: {e}")
return None, None
except Exception as e:
st.error(f"Error processing geocoding response: {e}")
return None, None
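# Illustrative usage (performs a real network call; values approximate):
#   geocode_address("1 City Hall Square, Boston, MA")  ->  (~42.36, ~-71.06)
# Geoapify returns GeoJSON, where coordinates are ordered [lon, lat]; they are
# swapped to the more conventional (lat, lon) above.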
def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> list[dict]:
"""Get nearby schools using Geoapify."""
if not GEOAPIFY_KEY:
return []
lat, lon = geocode_address(f"{address}, Boston, MA, USA")
if lat is None or lon is None:
st.warning(f"Could not geocode address: {address}")
return []
try:
resp = requests.get(
"https://api.geoapify.com/v2/places",
params={
"categories": "education.school",
"filter": f"circle:{lon},{lat},{radius}",
"limit": limit,
"apiKey": GEOAPIFY_KEY,
},
timeout=10
)
resp.raise_for_status()
schools = []
for feat in resp.json().get("features", []):
prop = feat.get("properties", {})
name = prop.get("name")
addr = prop.get("formatted")
if name and addr:
schools.append({"name": name, "address": addr})
return schools
except requests.exceptions.RequestException as e:
st.error(f"Nearby schools API request failed: {e}")
return []
except Exception as e:
st.error(f"Error processing nearby schools response: {e}")
return []
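# Illustrative return shape (entries hypothetical):
#   [{"name": "Some K-8 School", "address": "123 Example St, Boston, MA"}, ...]
# Geoapify's "education.school" category covers every school near the point,
# not only Boston Public Schools, so downstream filtering may be needed.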
def build_school_search_prompt(address: str) -> str:
"""Builds the prompt section listing nearby schools."""
if not address:
return initial_school_search_prompt
nearby_schools = get_nearby_schools(address, radius=2000, limit=10)
if not nearby_schools:
return f"No schools found near '{address}'. Please ensure the address is correct or try a broader area if applicable."
school_list_str = "\n".join(f"- {s['name']}: {s['address']}" for s in nearby_schools)
return (
f"Based on the residence '{address}', here are some nearby schools:\n{school_list_str}\n\n"
"Use this information and the provided documents to answer eligibility questions for the user's grade level."
)
def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
"""
Updates context_json in-place based on new_data extracted from LLM response.
Returns the updated context and a boolean indicating if residence changed.
"""
residence_changed = False
    current_res = str(context_json.get("residence") or "").strip()  # guard against None/non-string values
    new_res = str(new_data.get("residence") or "").strip()
if new_res and new_res != current_res:
context_json["residence"] = new_res
residence_changed = True
elif "residence" in new_data and not new_res and current_res:
context_json["residence"] = ""
residence_changed = True
for key, value in new_data.items():
if key != "residence":
new_val_str = str(value).strip() if value is not None else ""
old_val_str = str(context_json.get(key, "")).strip()
if new_val_str and new_val_str != old_val_str:
context_json[key] = value
            elif not new_val_str and old_val_str:  # key is always in new_data here; an explicit empty value clears the field
context_json[key] = ""
return context_json, residence_changed
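# Illustrative example (hypothetical values):
#   ctx = {"residence": "", "grade": "2", "school_choice": "", "summary": ""}
#   update_context(ctx, {"residence": "123 Main St"})
#   ->  ({"residence": "123 Main St", "grade": "2", ...}, True)
# The returned flag signals a residence change so the school list is rebuilt.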
# --- RAG Setup ---
@st.cache_resource
def load_embedding_model():
"""Loads the Sentence Transformer model."""
try:
return SentenceTransformer(EMBEDDING_MODEL_NAME)
except Exception as e:
st.error(f"Error loading embedding model '{EMBEDDING_MODEL_NAME}': {e}")
st.stop()
@st.cache_data
def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
"""Loads text documents from the specified directory."""
doc_texts = []
filenames = []
if not os.path.isdir(docs_path):
st.error(f"Documents directory '{docs_path}' not found. Please create it and add text files.")
return [], []
try:
for fname in os.listdir(docs_path):
if fname.lower().endswith(".txt"):
fpath = os.path.join(docs_path, fname)
try:
with open(fpath, 'r', encoding='utf-8') as f:
doc_texts.append(f.read())
filenames.append(fname)
except Exception as e:
st.warning(f"Could not read file {fname}: {e}")
if not doc_texts:
st.warning(f"No .txt files found or loaded from '{docs_path}'. RAG will be ineffective.")
return doc_texts, filenames
except Exception as e:
st.error(f"Error loading documents from '{docs_path}': {e}")
return [], []
@st.cache_resource(show_spinner="Creating document embeddings and FAISS index...")
def create_faiss_index(_embedder, doc_texts):
"""Creates FAISS index from document texts."""
if not doc_texts:
return None
try:
doc_embeddings = _embedder.encode(doc_texts, convert_to_numpy=True, show_progress_bar=True)
if doc_embeddings is None or doc_embeddings.shape[0] == 0:
st.error("Embedding failed, no document embeddings generated.")
return None
faiss.normalize_L2(doc_embeddings)
dimension = doc_embeddings.shape[1]
index = faiss.IndexFlatIP(dimension)
index.add(doc_embeddings)
return index
except Exception as e:
st.error(f"Error creating FAISS index: {e}")
return None
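# Design note: L2-normalizing the embeddings and searching with IndexFlatIP
# makes the inner product equivalent to cosine similarity, the standard metric
# for sentence-transformer embeddings.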
def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> list[str]:
"""Queries the FAISS index to retrieve relevant document chunks."""
if _index is None or not doc_texts:
return []
try:
query_embedding = _embedder.encode([query], convert_to_numpy=True)
if query_embedding is None or query_embedding.shape[0] == 0:
st.warning("Failed to generate query embedding.")
return []
faiss.normalize_L2(query_embedding)
distances, indices = _index.search(query_embedding, top_k)
return [doc_texts[i] for i in indices[0] if i != -1]
except Exception as e:
st.error(f"Error querying FAISS index: {e}")
return []
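# Illustrative usage (assumes the index and documents built above):
#   query_docs("Who is eligible for kindergarten?", index, embedder, doc_texts)
#   ->  up to TOP_K_DOCS whole document texts, ranked by cosine similarity
# Whole files are embedded as single vectors here; chunking long documents
# would be a natural refinement if retrieval quality suffers.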
# --- LLM Loading ---
@st.cache_resource(show_spinner="Loading Language Model...")
def load_llm():
"""Loads the Llama model using llama-cpp-python."""
try:
model_path = hf_hub_download(
repo_id=MODEL_REPO_ID,
filename=MODEL_FILENAME,
local_dir="models",
local_dir_use_symlinks=False
)
st.success(f"Model found at: {model_path}")
except Exception as e:
st.error(f"Error downloading model '{MODEL_FILENAME}' from '{MODEL_REPO_ID}': {e}")
st.info("Please ensure the model repository and filename are correct, and you have internet access.")
st.stop()
try:
llm = Llama(
model_path=model_path,
n_ctx=N_CTX,
n_threads=N_THREADS,
verbose=False
)
return llm
except Exception as e:
st.error(f"Error loading Llama model from path '{model_path}': {e}")
st.stop()
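# Context budget note: the assembled prompt plus up to MAX_TOKENS_RESPONSE
# completion tokens must fit within N_CTX (2048 here), so the prompt builder
# below truncates history and caps retrieval at TOP_K_DOCS documents.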
# --- Prompt Building ---
def build_full_prompt(
context_json: dict,
school_search_prompt: str,
history: list[dict],
max_history=5
) -> str:
"""Builds the final prompt string for the LLM."""
last_user_input = ""
if history and history[-1]["role"] == "user":
last_user_input = history[-1]["content"]
summary_info = context_json.get("summary", "")
rag_query = f"{last_user_input}\n\nContext Summary: {summary_info}".strip()
retrieved_docs = query_docs(rag_query, faiss_index, embedder, doc_texts_global, top_k=TOP_K_DOCS)
docs_context_str = "\n\n---\n\n".join(retrieved_docs)
if docs_context_str:
docs_context_str = f"DOCUMENT CONTEXT:\n{docs_context_str}\n---"
else:
docs_context_str = "DOCUMENT CONTEXT: None available."
recent_history = history[-(max_history * 2):]
conversation = []
for msg in recent_history:
role = "User" if msg["role"] == "user" else "Assistant"
conversation.append(f"{role}: {msg['content']}")
conversation_str = "\n".join(conversation)
prompt = f"""{system_prompt}
{docs_context_str}
CURRENT SITUATION CONTEXT:
{json.dumps(context_json, indent=2)}
SCHOOL SEARCH INFO:
{school_search_prompt}
{json_prompt}
CONVERSATION HISTORY:
{conversation_str}
Assistant:"""
return prompt
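# For reference, the assembled prompt reads top to bottom as:
#   system prompt -> document context -> current situation JSON
#   -> school search info -> JSON-update instructions -> conversation history
# The model is called in raw completion mode, so this plain-text layout stands
# in for a chat template.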
# --- Streamlit App UI and Logic ---
st.title("Boston Public Schools Enrollment Assistant 🏫")
st.markdown("Ask questions about enrolling in Boston Public Schools. I can help find nearby schools if you provide a residence address.")
llm = load_llm()
embedder = load_embedding_model()
doc_texts_global, filenames_global = load_documents(DOCS_PATH)
faiss_index = create_faiss_index(embedder, doc_texts_global)
if "messages" not in st.session_state:
st.session_state.messages = []
if "context_json" not in st.session_state:
st.session_state.context_json = {
"residence": "",
"grade": "",
"school_choice": "",
"summary": ""
}
if "school_search" not in st.session_state:
st.session_state.school_search = initial_school_search_prompt
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("What is your question? (e.g., 'I live at 123 Main St, my child is going into grade 2')"):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
message_placeholder = st.empty()
message_placeholder.markdown("Thinking...")
full_prompt = build_full_prompt(
st.session_state.context_json,
st.session_state.school_search,
st.session_state.messages
)
try:
response = llm(
full_prompt,
max_tokens=MAX_TOKENS_RESPONSE,
temperature=TEMPERATURE,
stop=["\nUser:", "\nAssistant:", "<|end_header_id|>", "<|eot_id|>"],
echo=False
)
raw_output = response["choices"][0]["text"].strip()
reply_text, new_data = extract_reply_and_json(raw_output)
updated_context, residence_changed = update_context(st.session_state.context_json, new_data)
st.session_state.context_json = updated_context
if residence_changed:
st.session_state.school_search = build_school_search_prompt(st.session_state.context_json.get("residence", ""))
message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._")
st.session_state.messages.append({"role": "assistant", "content": reply_text})
except Exception as e:
st.error(f"An error occurred during response generation: {e}")
error_message = "Sorry, I encountered an error processing your request."
message_placeholder.markdown(error_message)
st.session_state.messages.append({"role": "assistant", "content": error_message})
with st.sidebar:
st.subheader("ℹ️ Current Context")
st.json(st.session_state.context_json)
st.subheader("🏫 School Search Status")
st.text(st.session_state.school_search)