import streamlit as st
import os
import re
import json
import demjson3
import requests
import faiss
import numpy as np
import multiprocessing
import time
from huggingface_hub import hf_hub_download, login
from sentence_transformers import SentenceTransformer
from llama_cpp import Llama
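
# NOTE: these imports assume the Space's requirements.txt lists the usual PyPI
# package names: streamlit, demjson3, requests, faiss-cpu, numpy,
# huggingface_hub, sentence-transformers, llama-cpp-python.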

# --- Configuration & Constants ---

# Get API key from Streamlit secrets
try:
    GEOAPIFY_KEY = st.secrets["GEOAPIFY_KEY"]
except KeyError:
    st.error("Geoapify API key not found. Please add it to Streamlit secrets.")
    st.stop()

# HF_TOKEN is optional but recommended for downloads.
# st.secrets.get() returns None instead of raising KeyError, so only log in
# when a token is actually configured.
HF_TOKEN = st.secrets.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Model and RAG configuration
MODEL_REPO_ID = "bartowski/gemma-2-2b-it-GGUF"
MODEL_FILENAME = "gemma-2-2b-it-Q8_0.gguf"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
DOCS_PATH = "docs"
FAISS_INDEX_PATH = "bps_faiss.index"

# LLM parameters
N_CTX = 2048
MAX_TOKENS_RESPONSE = 350
TEMPERATURE = 0.5
N_THREADS = max(1, multiprocessing.cpu_count() - 1)  # never 0, even on a single-core host

# RAG parameters
TOP_K_DOCS = 3

# Import prompts from prompts.py, falling back to built-in defaults if the file is missing.
try:
    from prompts import system_prompt, json_prompt, initial_school_search_prompt
except ImportError:
    st.warning("Could not import prompts from prompts.py; using built-in default prompts.")
    system_prompt = """
You are a professional assistant that answers questions about enrollment in Boston Public Schools.
Be friendly and helpful. Families will ask questions and provide information, such as the child's residence, grade, and school preference.
It is essential that you do not provide inaccurate information. If you do not know something, respond accordingly!
"""
    json_prompt = """
You must provide a response to the user and an updated JSON object with any new facts you learn. Only update fields that are explicitly mentioned.
"""
    initial_school_search_prompt = """
Keep the conversation going and ask questions one at a time until you have all the information. With a residence and grade, you have the ability to look up nearby schools.
"""

# --- Helper Functions ---

def clean_reply_text(reply: str) -> str:
    """Removes potential JSON blocks and cleans up common LLM artifacts."""
    # Drop fenced ```json ... ``` blocks (the "json" fence label is optional).
    reply = re.sub(r"```(?:json)?\s*(\{.*?\})\s*```", "", reply, flags=re.DOTALL | re.IGNORECASE)
    # Drop a trailing bare JSON object.
    reply = re.sub(r"\s*\{.*\}\s*$", "", reply, flags=re.DOTALL)
    # Strip stray backticks, the word "json", square brackets, and blank-line runs.
    reply = re.sub(r"`", "", reply)
    reply = re.sub(r"(?i)\bjson\b", "", reply)
    reply = re.sub(r"[\[\]]", "", reply)
    reply = re.sub(r"\n{2,}", "\n", reply)
    return reply.strip()

def extract_reply_and_json(text: str) -> tuple[str, dict]:
    """
    Extracts the natural language reply and the last valid JSON object from the LLM response.
    Uses demjson3 for potentially more lenient parsing.
    """
    json_part = {}
    reply_part = text

    last_brace_open = text.rfind('{')
    if last_brace_open != -1:
        brace_level = 0
        last_brace_close = -1
        potential_json_str = text[last_brace_open:]
        for i, char in enumerate(potential_json_str):
            if char == '{':
                brace_level += 1
            elif char == '}':
                brace_level -= 1
                if brace_level == 0:
                    last_brace_close = last_brace_open + i
                    break
        if last_brace_close != -1:
            json_str = text[last_brace_open : last_brace_close + 1]
            try:
                parsed = demjson3.decode(json_str)
                if isinstance(parsed, dict):
                    json_part = parsed
                    reply_part = text[:last_brace_open].strip()
            except demjson3.JSONDecodeError:
                pass

    cleaned_reply = clean_reply_text(reply_part)
    if not cleaned_reply and json_part:
        cleaned_reply = clean_reply_text(text[:last_brace_open])
    elif not cleaned_reply and not json_part:
        cleaned_reply = ""
    return cleaned_reply, json_part
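
# Illustrative example (not from the original source): given the raw model output
#   'Thanks, I noted that. {"residence": "123 Main St", "grade": "2"}'
# extract_reply_and_json() returns
#   ('Thanks, I noted that.', {'residence': '123 Main St', 'grade': '2'})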

def geocode_address(address: str) -> tuple[float | None, float | None]:
    """Turn a free-form address into (lat, lon) using Geoapify."""
    if not GEOAPIFY_KEY:
        return None, None
    try:
        resp = requests.get(
            "https://api.geoapify.com/v1/geocode/search",
            params={"text": address, "limit": 1, "apiKey": GEOAPIFY_KEY},
            timeout=10,
        )
        resp.raise_for_status()
        features = resp.json().get("features", [])
        if not features:
            return None, None
        lon, lat = features[0]["geometry"]["coordinates"]
        return lat, lon
    except requests.exceptions.RequestException as e:
        st.error(f"Geocoding API request failed: {e}")
        return None, None
    except Exception as e:
        st.error(f"Error processing geocoding response: {e}")
        return None, None

def get_nearby_schools(address: str, radius: int = 2000, limit: int = 10) -> list[dict]:
    """Get nearby schools using Geoapify."""
    if not GEOAPIFY_KEY:
        return []
    lat, lon = geocode_address(f"{address}, Boston, MA, USA")
    if lat is None or lon is None:
        st.warning(f"Could not geocode address: {address}")
        return []
    try:
        resp = requests.get(
            "https://api.geoapify.com/v2/places",
            params={
                "categories": "education.school",
                "filter": f"circle:{lon},{lat},{radius}",
                "limit": limit,
                "apiKey": GEOAPIFY_KEY,
            },
            timeout=10,
        )
        resp.raise_for_status()
        schools = []
        for feat in resp.json().get("features", []):
            prop = feat.get("properties", {})
            name = prop.get("name")
            addr = prop.get("formatted")
            if name and addr:
                schools.append({"name": name, "address": addr})
        return schools
    except requests.exceptions.RequestException as e:
        st.error(f"Nearby schools API request failed: {e}")
        return []
    except Exception as e:
        st.error(f"Error processing nearby schools response: {e}")
        return []

def build_school_search_prompt(address: str) -> str:
    """Builds the prompt section listing nearby schools."""
    if not address:
        return initial_school_search_prompt
    nearby_schools = get_nearby_schools(address, radius=2000, limit=10)
    if not nearby_schools:
        return f"No schools found near '{address}'. Please ensure the address is correct or try a broader area if applicable."
    school_list_str = "\n".join(f"- {s['name']}: {s['address']}" for s in nearby_schools)
    return (
        f"Based on the residence '{address}', here are some nearby schools:\n{school_list_str}\n\n"
        "Use this information and the provided documents to answer eligibility questions for the user's grade level."
    )

def update_context(context_json: dict, new_data: dict) -> tuple[dict, bool]:
    """
    Updates context_json in place based on new_data extracted from the LLM response.
    Returns the updated context and a boolean indicating whether the residence changed.
    """
    residence_changed = False
    # Coerce to strings so None or numeric values from the model can't break .strip().
    current_res = str(context_json.get("residence") or "").strip()
    new_res = str(new_data.get("residence") or "").strip()
    if new_res and new_res != current_res:
        context_json["residence"] = new_res
        residence_changed = True
    elif "residence" in new_data and not new_res and current_res:
        context_json["residence"] = ""
        residence_changed = True

    for key, value in new_data.items():
        if key != "residence":
            new_val_str = str(value).strip() if value is not None else ""
            old_val_str = str(context_json.get(key, "")).strip()
            if new_val_str and new_val_str != old_val_str:
                context_json[key] = value
            elif not new_val_str and old_val_str:
                context_json[key] = ""
    return context_json, residence_changed
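
# Illustrative example (not from the original source):
#   update_context({"residence": "", "grade": ""}, {"residence": "123 Main St"})
# returns ({"residence": "123 Main St", "grade": ""}, True); the True flag is what
# triggers a fresh nearby-school lookup in the chat handler below.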

# --- RAG Setup ---
# Heavy resources are cached with st.cache_resource so they are not reloaded on
# every Streamlit rerun; underscore-prefixed arguments are excluded from cache hashing.

@st.cache_resource
def load_embedding_model():
    """Loads the Sentence Transformer model."""
    try:
        return SentenceTransformer(EMBEDDING_MODEL_NAME)
    except Exception as e:
        st.error(f"Error loading embedding model '{EMBEDDING_MODEL_NAME}': {e}")
        st.stop()

def load_documents(docs_path: str) -> tuple[list[str], list[str]]:
    """Loads text documents from the specified directory."""
    doc_texts = []
    filenames = []
    if not os.path.isdir(docs_path):
        st.error(f"Documents directory '{docs_path}' not found. Please create it and add text files.")
        return [], []
    try:
        for fname in os.listdir(docs_path):
            if fname.lower().endswith(".txt"):
                fpath = os.path.join(docs_path, fname)
                try:
                    with open(fpath, 'r', encoding='utf-8') as f:
                        doc_texts.append(f.read())
                        filenames.append(fname)
                except Exception as e:
                    st.warning(f"Could not read file {fname}: {e}")
        if not doc_texts:
            st.warning(f"No .txt files found or loaded from '{docs_path}'. RAG will be ineffective.")
        return doc_texts, filenames
    except Exception as e:
        st.error(f"Error loading documents from '{docs_path}': {e}")
        return [], []

@st.cache_resource
def create_faiss_index(_embedder, doc_texts):
    """Creates FAISS index from document texts."""
    if not doc_texts:
        return None
    try:
        doc_embeddings = _embedder.encode(doc_texts, convert_to_numpy=True, show_progress_bar=True)
        if doc_embeddings is None or doc_embeddings.shape[0] == 0:
            st.error("Embedding failed, no document embeddings generated.")
            return None
        # Normalize embeddings so inner-product search is equivalent to cosine similarity.
        faiss.normalize_L2(doc_embeddings)
        dimension = doc_embeddings.shape[1]
        index = faiss.IndexFlatIP(dimension)
        index.add(doc_embeddings)
        return index
    except Exception as e:
        st.error(f"Error creating FAISS index: {e}")
        return None

def query_docs(query: str, _index, _embedder, doc_texts, top_k=TOP_K_DOCS) -> list[str]:
    """Queries the FAISS index to retrieve relevant document chunks."""
    if _index is None or not doc_texts:
        return []
    try:
        query_embedding = _embedder.encode([query], convert_to_numpy=True)
        if query_embedding is None or query_embedding.shape[0] == 0:
            st.warning("Failed to generate query embedding.")
            return []
        faiss.normalize_L2(query_embedding)
        distances, indices = _index.search(query_embedding, top_k)
        return [doc_texts[i] for i in indices[0] if i != -1]
    except Exception as e:
        st.error(f"Error querying FAISS index: {e}")
        return []

# --- LLM Loading ---

@st.cache_resource
def load_llm():
    """Loads the Llama model using llama-cpp-python (cached across reruns)."""
    try:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO_ID,
            filename=MODEL_FILENAME,
            local_dir="models",
            local_dir_use_symlinks=False,
        )
        st.success(f"Model found at: {model_path}")
    except Exception as e:
        st.error(f"Error downloading model '{MODEL_FILENAME}' from '{MODEL_REPO_ID}': {e}")
        st.info("Please ensure the model repository and filename are correct, and you have internet access.")
        st.stop()
    try:
        llm = Llama(
            model_path=model_path,
            n_ctx=N_CTX,
            n_threads=N_THREADS,
            verbose=False,
        )
        return llm
    except Exception as e:
        st.error(f"Error loading Llama model from path '{model_path}': {e}")
        st.stop()

# --- Prompt Building ---

def build_full_prompt(
    context_json: dict,
    school_search_prompt: str,
    history: list[dict],
    max_history=5,
) -> str:
    """Builds the final prompt string for the LLM."""
    last_user_input = ""
    if history and history[-1]["role"] == "user":
        last_user_input = history[-1]["content"]

    summary_info = context_json.get("summary", "")
    rag_query = f"{last_user_input}\n\nContext Summary: {summary_info}".strip()
    retrieved_docs = query_docs(rag_query, faiss_index, embedder, doc_texts_global, top_k=TOP_K_DOCS)
    docs_context_str = "\n\n---\n\n".join(retrieved_docs)
    if docs_context_str:
        docs_context_str = f"DOCUMENT CONTEXT:\n{docs_context_str}\n---"
    else:
        docs_context_str = "DOCUMENT CONTEXT: None available."

    recent_history = history[-(max_history * 2):]
    conversation = []
    for msg in recent_history:
        role = "User" if msg["role"] == "user" else "Assistant"
        conversation.append(f"{role}: {msg['content']}")
    conversation_str = "\n".join(conversation)

    prompt = f"""{system_prompt}
{docs_context_str}
CURRENT SITUATION CONTEXT:
{json.dumps(context_json, indent=2)}
SCHOOL SEARCH INFO:
{school_search_prompt}
{json_prompt}
CONVERSATION HISTORY:
{conversation_str}
Assistant:"""
    return prompt

# --- Streamlit App UI and Logic ---

st.set_page_config(page_title="Boston School Choice Chatbot", page_icon="🏫", layout="wide")
st.title("Boston Public Schools Enrollment Assistant 🏫")
st.markdown("Ask questions about enrolling in Boston Public Schools. I can help find nearby schools if you provide a residence address.")

llm = load_llm()
embedder = load_embedding_model()
doc_texts_global, filenames_global = load_documents(DOCS_PATH)
faiss_index = create_faiss_index(embedder, doc_texts_global)
| if "messages" not in st.session_state: | |
| st.session_state.messages = [] | |
| if "context_json" not in st.session_state: | |
| st.session_state.context_json = { | |
| "residence": "", | |
| "grade": "", | |
| "school_choice": "", | |
| "summary": "" | |
| } | |
| if "school_search" not in st.session_state: | |
| st.session_state.school_search = initial_school_search_prompt | |
| for message in st.session_state.messages: | |
| with st.chat_message(message["role"]): | |
| st.markdown(message["content"]) | |

if prompt := st.chat_input("What is your question? (e.g., 'I live at 123 Main St, my child is going into grade 2')"):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        message_placeholder = st.empty()
        message_placeholder.markdown("Thinking...")
        full_prompt = build_full_prompt(
            st.session_state.context_json,
            st.session_state.school_search,
            st.session_state.messages
        )
        try:
            response = llm(
                full_prompt,
                max_tokens=MAX_TOKENS_RESPONSE,
                temperature=TEMPERATURE,
                # "<end_of_turn>" is Gemma's turn terminator; the other stop strings
                # guard against the model writing the next conversation turn itself.
                stop=["\nUser:", "\nAssistant:", "<end_of_turn>", "<|end_header_id|>", "<|eot_id|>"],
                echo=False
            )
            raw_output = response["choices"][0]["text"].strip()
            reply_text, new_data = extract_reply_and_json(raw_output)
            updated_context, residence_changed = update_context(st.session_state.context_json, new_data)
            st.session_state.context_json = updated_context
            if residence_changed:
                st.session_state.school_search = build_school_search_prompt(st.session_state.context_json.get("residence", ""))
            message_placeholder.markdown(reply_text if reply_text else "_Assistant had trouble generating a response._")
            st.session_state.messages.append({"role": "assistant", "content": reply_text})
        except Exception as e:
            st.error(f"An error occurred during response generation: {e}")
            error_message = "Sorry, I encountered an error processing your request."
            message_placeholder.markdown(error_message)
            st.session_state.messages.append({"role": "assistant", "content": error_message})

with st.sidebar:
    st.subheader("ℹ️ Current Context")
    st.json(st.session_state.context_json)
    st.subheader("🏫 School Search Status")
    st.text(st.session_state.school_search)