File size: 7,640 Bytes
e9e366a
 
562ed56
34345fa
 
86e3856
abf7d79
 
 
 
 
e9e366a
abf7d79
 
 
562ed56
abf7d79
e9e366a
abf7d79
e9e366a
 
562ed56
 
34345fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99b75a1
 
 
 
 
 
 
 
abf7d79
562ed56
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2817408
562ed56
 
2817408
562ed56
 
 
 
2817408
562ed56
 
2817408
 
 
562ed56
 
 
 
 
 
abf7d79
 
34345fa
abf7d79
 
 
 
 
 
 
99b75a1
40f76db
34345fa
 
 
 
 
 
 
 
 
 
 
 
99b75a1
 
 
 
 
 
 
 
 
 
34345fa
 
 
 
 
 
 
 
 
 
99b75a1
 
 
 
 
 
 
34345fa
99b75a1
 
 
 
34345fa
99b75a1
 
 
 
34345fa
40f76db
99b75a1
 
abf7d79
 
 
34345fa
abf7d79
 
 
 
 
 
 
40f76db
 
 
 
 
 
 
 
 
 
 
99b75a1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
from huggingface_hub import InferenceClient
from config import BASE_MODEL, MY_MODEL, HF_TOKEN
import pandas as pd
import os
from src.rag_engine import RAGEngine, SchoolDocument

class SchoolChatbot:
    """
    Scaffolding around a Hugging Face model. Modify this class to specify how
    the model receives prompts and generates responses.

    Example usage:
        chatbot = SchoolChatbot()
        response = chatbot.get_response("What schools offer Spanish programs?")
    """

    def __init__(self, school_csv='BPS.csv', programs_csv='BPS-special-programs.csv'):
        """
        Initialize the chatbot with a HF model ID and prepare the RAG index.

        Args:
            school_csv (str): Path to the main school data CSV.
            programs_csv (str): Path to the special programs CSV.
        """
        # define MY_MODEL in config.py if you create a new model in the HuggingFace Hub
        model_id = MY_MODEL if MY_MODEL else BASE_MODEL
        self.client = InferenceClient(model=model_id, token=HF_TOKEN)
        self.school_csv = school_csv
        self.programs_csv = programs_csv

        # Initialize the RAG engine, then load or build its index.
        self.rag_engine = RAGEngine()
        self._setup_rag()

    def _setup_rag(self):
        """
        Set up the RAG engine by either loading a pre-built index or building a new one.

        Falls back to a full rebuild if any index artifact is missing or the
        load fails (e.g. a corrupt/partial index on disk).
        """
        index_dir = 'models'
        index_path = os.path.join(index_dir, 'school_rag')

        # All three artifacts must be present for a load attempt to make sense.
        index_files = (
            f"{index_path}_documents.pkl",
            f"{index_path}_embeddings.pkl",
            f"{index_path}_faiss.index",
        )
        if all(os.path.exists(path) for path in index_files):
            try:
                self.rag_engine.load_index(index_path)
                print("Loaded existing RAG index.")
                return
            except Exception as e:
                # Don't crash on a bad index — rebuild from the CSVs instead.
                print(f"Error loading index: {e}. Building new index...")

        # Build a new index from the source data and persist it for next run.
        os.makedirs(index_dir, exist_ok=True)
        self.rag_engine.process_school_data(self.school_csv, self.programs_csv)
        self.rag_engine.build_index(index_path)
        print("Built and saved new RAG index.")

    @staticmethod
    def load_age_cutoffs(filepath='age_cutoffs_2025.txt'):
        """
        Read the age-cutoff reference text used in the prompt.

        Args:
            filepath (str): Path to the age cutoff text file.

        Returns:
            str: File contents, or a placeholder error section if the file
                 is missing (so prompt assembly never raises).
        """
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                return f.read()
        except FileNotFoundError:
            return "# AGE_CUTOFFS\n<Error: age cutoff file not found>"

    @staticmethod
    def format_school_data(
        school_csv='BPS.csv',
        programs_csv='BPS-special-programs.csv',
    ):
        """
        Merges main school data with special program indicators and formats it for in-context prompting.

        Args:
            school_csv (str): Path to the main school data CSV.
            programs_csv (str): Path to the special programs CSV.

        Returns:
            str: Formatted string for # SCHOOL_DATA section, or an error
                 placeholder section if loading/merging fails.
        """
        try:
            # Load both datasets
            schools_df = pd.read_csv(school_csv)
            programs_df = pd.read_csv(programs_csv)

            # Merge on School Name (left join: keep every school even
            # if it has no special-programs row).
            merged_df = pd.merge(schools_df, programs_df, on="School Name", how="left")

            # Use more concise formatting: one line per school, with a Y/N
            # flag indicating whether any special program is offered.
            school_lines = []
            for _, row in merged_df.iterrows():
                # Collect all program columns marked "Yes" (skip the
                # "School Name" key column itself).
                programs_offered = [col for col in programs_df.columns[1:] if row.get(col, "") == "Yes"]
                programs_str = "Y" if programs_offered else "N"

                school_lines.append(
                    f'- {row["School Name"]}: {row["Grades Served"]}, {row["School Type"]}, {programs_str}'
                )
            # Remove duplicates while preserving first-seen order —
            # list(set(...)) would make the prompt non-deterministic.
            school_lines = list(dict.fromkeys(school_lines))
            return "# SCHOOL_DATA\n" + "\n".join(school_lines)

        except Exception as e:
            return f"# SCHOOL_DATA\n<Error loading or merging data: {e}>"

    def format_prompt(self, user_input):
        """
        Format the user's input into a proper prompt using RAG to retrieve relevant context.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: A formatted prompt ready for the model
        """
        system_message = """You are a helpful and accurate school enrollment assistant for Boston Public Schools (BPS).
        You can provide information about school options, locations, programs, and other details
        to help families make informed decisions about their children's education.
        
        Provide clear, fact-based, and non-misleading information using the data provided below.
        Focus on answering only the user's specific question using the relevant school information.
        
        When answering questions about specific schools, neighborhoods, or programs, prioritize information 
        from the RETRIEVED_SCHOOLS section, which contains the most relevant schools for the user's query.

        DO NOT make up or hallucinate any school information.

        If the retrieved schools don't match what the user is looking for, acknowledge this limitation
        and suggest they contact BPS directly at (617) 635-9010 for more information.
        """

        age_cutoffs_section = SchoolChatbot.load_age_cutoffs()

        transportation_section = """# TRANSPORTATION_ELIGIBILITY
        - K0–K1: Bus eligible if >0.75 miles from school
        - K2–5: Bus eligible if >1 mile
        - Grades 6–8: Bus eligible if >1.5 miles
        - Grades 9–12: MBTA pass provided
        """

        # Instead of including the full dataset (token-heavy), retrieve only
        # the schools most relevant to this query via the RAG engine.
        retrieved_docs = self.rag_engine.retrieve(user_input, top_k=3)
        retrieved_context = self.rag_engine.format_retrieved_context(retrieved_docs)

        examples_section = """# EXAMPLES
            User: My child is turning 5 on August 15 and we live in 02124. What grade can they enter, and what schools are available?
            Assistant: Since your child turns 5 before September 1, they are eligible for K2. Based on your zip code (02124), eligible schools may include Joseph Lee K-8, Mildred Avenue, and TechBoston Academy.
            """

        # Combine all sections into the final chat-template prompt.
        prompt = (
            f"<|system|>\n{system_message}\n"
            f"{age_cutoffs_section}\n"
            f"{transportation_section}\n"
            f"{retrieved_context}\n"
            f"{examples_section}\n"
            f"<|user|>\n{user_input}\n<|assistant|>\n"
        )

        return prompt

    def get_response(self, user_input):
        """
        Generate responses to user questions using RAG and the language model.

        Args:
            user_input (str): The user's question about Boston schools

        Returns:
            str: The chatbot's response
        """
        prompt = self.format_prompt(user_input)

        # Generate response using the model; sampling settings trade a bit of
        # determinism for less repetitive answers.
        response = self.client.text_generation(
            prompt,
            max_new_tokens=512,
            temperature=0.7,
            do_sample=True,
            repetition_penalty=1.1
        )

        return response