Spaces:
Sleeping
Sleeping
| from huggingface_hub import InferenceClient | |
| from config import BASE_MODEL, MY_MODEL, HF_TOKEN | |
| import pandas as pd | |
| import os | |
| from src.rag_engine import RAGEngine, SchoolDocument | |
| class SchoolChatbot: | |
| """ | |
| This class is extra scaffolding around a model. Modify this class to specify how the model recieves prompts and generates responses. | |
| Example usage: | |
| chatbot = SchoolChatbot() | |
| response = chatbot.get_response("What schools offer Spanish programs?") | |
| """ | |
| def __init__(self, school_csv='BPS.csv', programs_csv='BPS-special-programs.csv'): | |
| """ | |
| Initialize the chatbot with a HF model ID | |
| """ | |
| model_id = MY_MODEL if MY_MODEL else BASE_MODEL # define MY_MODEL in config.py if you create a new model in the HuggingFace Hub | |
| self.client = InferenceClient(model=model_id, token=HF_TOKEN) | |
| self.school_csv = school_csv | |
| self.programs_csv = programs_csv | |
| # Initialize the RAG engine | |
| self.rag_engine = RAGEngine() | |
| # Set up the RAG index | |
| self._setup_rag() | |
| def _setup_rag(self): | |
| """ | |
| Set up the RAG engine by either loading a pre-built index or building a new one. | |
| """ | |
| index_dir = 'models' | |
| index_path = os.path.join(index_dir, 'school_rag') | |
| # Check if index files exist | |
| if (os.path.exists(f"{index_path}_documents.pkl") and | |
| os.path.exists(f"{index_path}_embeddings.pkl") and | |
| os.path.exists(f"{index_path}_faiss.index")): | |
| # Load existing index | |
| try: | |
| self.rag_engine.load_index(index_path) | |
| print("Loaded existing RAG index.") | |
| return | |
| except Exception as e: | |
| print(f"Error loading index: {e}. Building new index...") | |
| # Build new index | |
| os.makedirs(index_dir, exist_ok=True) | |
| self.rag_engine.process_school_data(self.school_csv, self.programs_csv) | |
| self.rag_engine.build_index(index_path) | |
| print("Built and saved new RAG index.") | |
| def load_age_cutoffs(filepath='age_cutoffs_2025.txt'): | |
| try: | |
| with open(filepath, 'r', encoding='utf-8') as f: | |
| return f.read() | |
| except FileNotFoundError: | |
| return "# AGE_CUTOFFS\n<Error: age cutoff file not found>" | |
| def format_school_data( | |
| school_csv='BPS.csv', | |
| programs_csv='BPS-special-programs.csv', | |
| ): | |
| """ | |
| Merges main school data with special program indicators and formats it for in-context prompting. | |
| Args: | |
| school_csv (str): Path to the main school data CSV. | |
| programs_csv (str): Path to the special programs CSV. | |
| max_schools (int or None): Max number of schools to include. | |
| Returns: | |
| str: Formatted string for # SCHOOL_DATA section. | |
| """ | |
| try: | |
| # Load both datasets | |
| schools_df = pd.read_csv(school_csv) | |
| programs_df = pd.read_csv(programs_csv) | |
| # Merge on School Name | |
| merged_df = pd.merge(schools_df, programs_df, on="School Name", how="left") | |
| # Use more concise formatting | |
| school_lines = [] | |
| for _, row in merged_df.iterrows(): | |
| # Collect all programs marked "Yes" | |
| programs_offered = [col for col in programs_df.columns[1:] if row.get(col, "") == "Yes"] | |
| programs_str = "Y" if programs_offered else "N" | |
| school_lines.append( | |
| f'- {row["School Name"]}: {row["Grades Served"]}, {row["School Type"]}, {programs_str}' | |
| ) | |
| school_lines = list(set(school_lines)) # Remove duplicates | |
| return "# SCHOOL_DATA\n" + "\n".join(school_lines) | |
| except Exception as e: | |
| return f"# SCHOOL_DATA\n<Error loading or merging data: {e}>" | |
| def format_prompt(self, user_input): | |
| """ | |
| Format the user's input into a proper prompt using RAG to retrieve relevant context. | |
| Args: | |
| user_input (str): The user's question about Boston schools | |
| Returns: | |
| str: A formatted prompt ready for the model | |
| """ | |
| system_message = """You are a helpful and accurate school enrollment assistant for Boston Public Schools (BPS). | |
| You can provide information about school options, locations, programs, and other details | |
| to help families make informed decisions about their children's education. | |
| Provide clear, fact-based, and non-misleading information using the data provided below. | |
| Focus on answering only the user's specific question using the relevant school information. | |
| When answering questions about specific schools, neighborhoods, or programs, prioritize information | |
| from the RETRIEVED_SCHOOLS section, which contains the most relevant schools for the user's query. | |
| DO NOT make up or hallucinate any school information. | |
| If the retrieved schools don't match what the user is looking for, acknowledge this limitation | |
| and suggest they contact BPS directly at (617) 635-9010 for more information. | |
| """ | |
| age_cutoffs_section = SchoolChatbot.load_age_cutoffs() | |
| transportation_section = """# TRANSPORTATION_ELIGIBILITY | |
| - K0–K1: Bus eligible if >0.75 miles from school | |
| - K2–5: Bus eligible if >1 mile | |
| - Grades 6–8: Bus eligible if >1.5 miles | |
| - Grades 9–12: MBTA pass provided | |
| """ | |
| # Instead of including all school data, retrieve relevant schools using RAG | |
| retrieved_docs = self.rag_engine.retrieve(user_input, top_k=3) | |
| retrieved_context = self.rag_engine.format_retrieved_context(retrieved_docs) | |
| # Comment out the full dataset reference to reduce token usage | |
| # school_data_section = SchoolChatbot.format_school_data( | |
| # school_csv=self.school_csv, | |
| # programs_csv=self.programs_csv, | |
| # ) | |
| examples_section = """# EXAMPLES | |
| User: My child is turning 5 on August 15 and we live in 02124. What grade can they enter, and what schools are available? | |
| Assistant: Since your child turns 5 before September 1, they are eligible for K2. Based on your zip code (02124), eligible schools may include Joseph Lee K-8, Mildred Avenue, and TechBoston Academy. | |
| """ | |
| # Combine all sections into the final prompt | |
| # f"{school_data_section}\n" # Comment out the full dataset section | |
| prompt = ( | |
| f"<|system|>\n{system_message}\n" | |
| f"{age_cutoffs_section}\n" | |
| f"{transportation_section}\n" | |
| f"{retrieved_context}\n" | |
| f"{examples_section}\n" | |
| f"<|user|>\n{user_input}\n<|assistant|>\n" | |
| ) | |
| print(prompt) | |
| return prompt | |
| def get_response(self, user_input): | |
| """ | |
| Generate responses to user questions using RAG and the language model. | |
| Args: | |
| user_input (str): The user's question about Boston schools | |
| Returns: | |
| str: The chatbot's response | |
| """ | |
| prompt = self.format_prompt(user_input) | |
| # Generate response using the model | |
| response = self.client.text_generation( | |
| prompt, | |
| max_new_tokens=512, | |
| temperature=0.7, | |
| do_sample=True, | |
| repetition_penalty=1.1 | |
| ) | |
| return response | |