Spaces:
Sleeping
Sleeping
File size: 7,640 Bytes
e9e366a 562ed56 34345fa 86e3856 abf7d79 e9e366a abf7d79 562ed56 abf7d79 e9e366a abf7d79 e9e366a 562ed56 34345fa 99b75a1 abf7d79 562ed56 2817408 562ed56 2817408 562ed56 2817408 562ed56 2817408 562ed56 abf7d79 34345fa abf7d79 99b75a1 40f76db 34345fa 99b75a1 34345fa 99b75a1 34345fa 99b75a1 34345fa 99b75a1 34345fa 40f76db 99b75a1 abf7d79 34345fa abf7d79 40f76db 99b75a1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 | from huggingface_hub import InferenceClient
from config import BASE_MODEL, MY_MODEL, HF_TOKEN
import pandas as pd
import os
from src.rag_engine import RAGEngine, SchoolDocument
class SchoolChatbot:
    """
    This class is extra scaffolding around a model. Modify this class to specify
    how the model receives prompts and generates responses.

    Each prompt is assembled from a fixed system message, an age-cutoff section
    (read from a text file), a transportation-eligibility section, and school
    context retrieved per question by a RAGEngine.

    Example usage:
        chatbot = SchoolChatbot()
        response = chatbot.get_response("What schools offer Spanish programs?")
    """

    def __init__(self, school_csv='BPS.csv', programs_csv='BPS-special-programs.csv'):
        """
        Initialize the inference client and the RAG index.

        Args:
            school_csv (str): Path to the main school data CSV.
            programs_csv (str): Path to the special programs CSV.
        """
        # Define MY_MODEL in config.py if you create a new model in the HuggingFace Hub;
        # otherwise fall back to BASE_MODEL.
        model_id = MY_MODEL if MY_MODEL else BASE_MODEL
        self.client = InferenceClient(model=model_id, token=HF_TOKEN)
        self.school_csv = school_csv
        self.programs_csv = programs_csv
        # Initialize the RAG engine, then load or build its index.
        self.rag_engine = RAGEngine()
        self._setup_rag()

    def _setup_rag(self):
        """
        Set up the RAG engine by either loading a pre-built index or building a new one.

        The index is loaded from ``models/school_rag`` only when all three of its
        files exist. If any file is missing, or loading raises, the index is
        rebuilt from ``self.school_csv`` / ``self.programs_csv`` and saved.
        """
        index_dir = 'models'
        index_path = os.path.join(index_dir, 'school_rag')
        required_files = (
            f"{index_path}_documents.pkl",
            f"{index_path}_embeddings.pkl",
            f"{index_path}_faiss.index",
        )
        # NOTE(review): the .pkl files are presumably unpickled by load_index;
        # only ever point this at trusted, locally built files.
        if all(os.path.exists(path) for path in required_files):
            try:
                self.rag_engine.load_index(index_path)
                print("Loaded existing RAG index.")
                return
            except Exception as e:
                # A stale or corrupt index is recoverable: fall through and rebuild.
                print(f"Error loading index: {e}. Building new index...")
        os.makedirs(index_dir, exist_ok=True)
        self.rag_engine.process_school_data(self.school_csv, self.programs_csv)
        self.rag_engine.build_index(index_path)
        print("Built and saved new RAG index.")

    @staticmethod
    def load_age_cutoffs(filepath='age_cutoffs_2025.txt'):
        """
        Return the contents of the age-cutoff file as a prompt section.

        Args:
            filepath (str): Path to the UTF-8 text file with the age cutoffs.

        Returns:
            str: The file contents, or an "# AGE_CUTOFFS" error placeholder if
            the file does not exist.
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            return "# AGE_CUTOFFS\n<Error: age cutoff file not found>"

    @staticmethod
    def format_school_data(
        school_csv='BPS.csv',
        programs_csv='BPS-special-programs.csv',
        max_schools=None,
    ):
        """
        Merges main school data with special program indicators and formats it for in-context prompting.

        Each school becomes one line of the form
        ``- <School Name>: <Grades Served>, <School Type>, <Y|N>``, where the last
        field is ``Y`` when any program column for that school equals "Yes".
        Duplicate lines are removed and first-seen order is kept, so the output
        is deterministic.

        Args:
            school_csv (str): Path to the main school data CSV.
            programs_csv (str): Path to the special programs CSV. Every column
                other than "School Name" is treated as a program flag.
            max_schools (int or None): Max number of schools to include; None
                (the default) includes all of them.

        Returns:
            str: Formatted string for # SCHOOL_DATA section, or an error
            placeholder under the same header if loading or merging fails.
        """
        try:
            # Load both datasets
            schools_df = pd.read_csv(school_csv)
            programs_df = pd.read_csv(programs_csv)
            # Merge on School Name
            merged_df = pd.merge(schools_df, programs_df, on="School Name", how="left")
            # Select program columns by name, not position, so the join key is never
            # treated as a program flag regardless of where it sits in the CSV.
            program_cols = [col for col in programs_df.columns if col != "School Name"]
            school_lines = []
            for _, row in merged_df.iterrows():
                # Schools without a programs row get NaN here, which never equals "Yes".
                has_program = any(row.get(col, "") == "Yes" for col in program_cols)
                programs_str = "Y" if has_program else "N"
                school_lines.append(
                    f'- {row["School Name"]}: {row["Grades Served"]}, {row["School Type"]}, {programs_str}'
                )
            # dict.fromkeys removes duplicates while preserving order; set() did not,
            # so the prompt text varied between processes (hash randomization).
            school_lines = list(dict.fromkeys(school_lines))
            if max_schools is not None:
                school_lines = school_lines[:max(max_schools, 0)]
            return "# SCHOOL_DATA\n" + "\n".join(school_lines)
        except Exception as e:
            return f"# SCHOOL_DATA\n<Error loading or merging data: {e}>"

    def format_prompt(self, user_input):
        """
        Format the user's input into a proper prompt using RAG to retrieve relevant context.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: A formatted prompt ready for the model
        """
        system_message = """You are a helpful and accurate school enrollment assistant for Boston Public Schools (BPS).
You can provide information about school options, locations, programs, and other details
to help families make informed decisions about their children's education.
Provide clear, fact-based, and non-misleading information using the data provided below.
Focus on answering only the user's specific question using the relevant school information.
When answering questions about specific schools, neighborhoods, or programs, prioritize information
from the RETRIEVED_SCHOOLS section, which contains the most relevant schools for the user's query.
DO NOT make up or hallucinate any school information.
If the retrieved schools don't match what the user is looking for, acknowledge this limitation
and suggest they contact BPS directly at (617) 635-9010 for more information.
"""
        age_cutoffs_section = SchoolChatbot.load_age_cutoffs()
        transportation_section = """# TRANSPORTATION_ELIGIBILITY
- K0–K1: Bus eligible if >0.75 miles from school
- K2–5: Bus eligible if >1 mile
- Grades 6–8: Bus eligible if >1.5 miles
- Grades 9–12: MBTA pass provided
"""
        # Only the top-3 schools relevant to this question are included, not the full
        # dataset from format_school_data, to keep token usage down.
        retrieved_docs = self.rag_engine.retrieve(user_input, top_k=3)
        retrieved_context = self.rag_engine.format_retrieved_context(retrieved_docs)
        examples_section = """# EXAMPLES
User: My child is turning 5 on August 15 and we live in 02124. What grade can they enter, and what schools are available?
Assistant: Since your child turns 5 before September 1, they are eligible for K2. Based on your zip code (02124), eligible schools may include Joseph Lee K-8, Mildred Avenue, and TechBoston Academy.
"""
        # Combine all sections into the final prompt using <|system|>/<|user|>/<|assistant|> role tags.
        prompt = (
            f"<|system|>\n{system_message}\n"
            f"{age_cutoffs_section}\n"
            f"{transportation_section}\n"
            f"{retrieved_context}\n"
            f"{examples_section}\n"
            f"<|user|>\n{user_input}\n<|assistant|>\n"
        )
        # Debug aid: echoes the full prompt (including the user's question) to stdout.
        print(prompt)
        return prompt

    def get_response(self, user_input):
        """
        Generate responses to user questions using RAG and the language model.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: The chatbot's response
        """
        prompt = self.format_prompt(user_input)
        # Sampled generation with a mild repetition penalty.
        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            repetition_penalty=1.1
        )
        return response
|