File size: 8,012 Bytes
34345fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
import pandas as pd
import numpy as np
import faiss
import os
from typing import List, Dict, Tuple, Any, Optional
from sentence_transformers import SentenceTransformer
import pickle
import re

class SchoolDocument:
    """
    Represents a document containing information about a school,
    which can be used for retrieval and context building.
    """
    def __init__(self, 
                 school_name: str, 
                 content: str, 
                 metadata: Optional[Dict[str, Any]] = None):
        """
        Args:
            school_name: Display name of the school.
            content: Human-readable description used for embedding and retrieval.
            metadata: Optional extra fields (address, programs, contacts, ...).
                Defaults to a fresh empty dict to avoid a shared mutable default.
        """
        self.school_name = school_name
        self.content = content
        self.metadata = metadata or {}
    
    def __str__(self) -> str:
        return f"{self.school_name}: {self.content}"

    def __repr__(self) -> str:
        # Unambiguous form for logs/debugging; __str__ stays the display form.
        return (f"SchoolDocument(school_name={self.school_name!r}, "
                f"content={self.content!r}, metadata={self.metadata!r})")

class RAGEngine:
    """
    Retrieval-Augmented Generation engine for the Boston School Chatbot.
    This class handles the embedding, indexing, and retrieval of school information
    to provide context-relevant responses.
    """
    
    def __init__(self, embedding_model: str = "all-MiniLM-L6-v2"):
        """
        Initialize the RAG engine with a sentence transformer model for embeddings.
        
        Args:
            embedding_model: The HuggingFace model ID to use for embeddings
        """
        self.embedding_model = SentenceTransformer(embedding_model)
        self.documents: List[SchoolDocument] = []
        self.embeddings = None      # float32 ndarray of shape (n_docs, dim) once built
        self.faiss_index = None     # faiss.IndexFlatL2 once built
        self.index_built = False
        
    def process_school_data(self, 
                           school_csv: str = 'BPS.csv',
                           programs_csv: str = 'BPS-special-programs.csv') -> List[SchoolDocument]:
        """
        Process school data from CSV files and create document objects.
        
        Args:
            school_csv: Path to the school data CSV
            programs_csv: Path to the special programs CSV
            
        Returns:
            List of SchoolDocument objects (also stored on self.documents)
        """
        # Load both datasets
        schools_df = pd.read_csv(school_csv)
        programs_df = pd.read_csv(programs_csv)
        
        # Left-join so schools without a special-programs row are still kept.
        merged_df = pd.merge(schools_df, programs_df, on="School Name", how="left")
        
        # Create documents for each school
        documents = []
        
        for _, row in merged_df.iterrows():
            school_name = row["School Name"]
            
            # Extract address components to identify the neighborhood.
            # NOTE: a left join can leave Address as NaN (a float), and NaN is
            # truthy in Python — re.search on it would raise TypeError, so
            # normalize any non-string value to "".
            raw_address = row["Address"] if "Address" in row else ""
            address = raw_address if isinstance(raw_address, str) else ""
            zip_code = ""
            neighborhood = ""
            
            if address:
                # Zip code follows the state abbreviation, e.g. "... MA 02118"
                zip_match = re.search(r'MA\s+(\d{5})', address)
                if zip_match:
                    zip_code = zip_match.group(1)
                
                # Neighborhood is the comma-delimited token before ", MA"
                neighborhood_match = re.search(r'([A-Za-z\s]+),\s+MA', address)
                if neighborhood_match:
                    neighborhood = neighborhood_match.group(1).strip()
            
            # Collect all program columns marked "Yes" (skip the join-key column)
            programs_offered = [col for col in programs_df.columns[1:]
                                if row.get(col) == "Yes"]
            
            # Create the content string
            content = f"{school_name} is a {row.get('School Type', 'N/A')} school serving grades {row.get('Grades Served', 'N/A')}."
            
            if address:
                content += f" Located at {address}."
            
            if programs_offered:
                # Format program names to be more readable
                readable_programs = [p.replace('_', ' ').title() for p in programs_offered]
                content += f" Special programs include: {', '.join(readable_programs)}."
            
            # Create metadata
            metadata = {
                "grades": row.get("Grades Served", ""),
                "type": row.get("School Type", ""),
                "address": address,
                "zip_code": zip_code,
                "neighborhood": neighborhood,
                "programs": programs_offered,
                "phone": row.get("Phone Number", "") if "Phone Number" in row else "",
                "email": row.get("Email Address", "") if "Email Address" in row else ""
            }
            
            documents.append(SchoolDocument(school_name, content, metadata))
        
        self.documents = documents
        return documents
    
    def build_index(self, save_path: Optional[str] = None) -> None:
        """
        Build the FAISS index for fast similarity search.
        
        Args:
            save_path: Optional path prefix to save the index and documents
        
        Raises:
            ValueError: If no documents have been processed yet.
        """
        if not self.documents:
            raise ValueError("No documents to index. Call process_school_data first.")
        
        # Create text chunks for embedding
        texts = [doc.content for doc in self.documents]
        
        # Generate embeddings; FAISS requires float32
        embeddings = self.embedding_model.encode(texts)
        self.embeddings = embeddings.astype('float32')
        
        # Build the FAISS index (exact L2 search)
        dimension = self.embeddings.shape[1]
        self.faiss_index = faiss.IndexFlatL2(dimension)
        self.faiss_index.add(self.embeddings)
        self.index_built = True
        
        # Save if a path is provided
        if save_path:
            parent = os.path.dirname(save_path)
            # dirname is "" for a bare filename; makedirs("") would raise
            if parent:
                os.makedirs(parent, exist_ok=True)
            with open(f"{save_path}_documents.pkl", "wb") as f:
                pickle.dump(self.documents, f)
            with open(f"{save_path}_embeddings.pkl", "wb") as f:
                pickle.dump(self.embeddings, f)
            faiss.write_index(self.faiss_index, f"{save_path}_faiss.index")
    
    def load_index(self, load_path: str) -> None:
        """
        Load a previously built index from disk.
        
        Args:
            load_path: Path prefix for the saved files
        
        WARNING: pickle.load executes arbitrary code from the file; only load
        index files that this application itself produced (never untrusted input).
        """
        with open(f"{load_path}_documents.pkl", "rb") as f:
            self.documents = pickle.load(f)
        with open(f"{load_path}_embeddings.pkl", "rb") as f:
            self.embeddings = pickle.load(f)
        self.faiss_index = faiss.read_index(f"{load_path}_faiss.index")
        self.index_built = True
    
    def retrieve(self, query: str, top_k: int = 3) -> List[SchoolDocument]:
        """
        Retrieve the most relevant documents for a given query.
        
        Args:
            query: The user's query
            top_k: Number of documents to retrieve (capped at the corpus size)
            
        Returns:
            List of the most relevant school documents
        
        Raises:
            ValueError: If the index has not been built or loaded.
        """
        if not self.index_built:
            raise ValueError("Index not built. Call build_index first.")
        
        # Encode the query; FAISS requires float32
        query_embedding = self.embedding_model.encode([query])
        query_embedding = query_embedding.astype('float32')
        
        # Search the index
        distances, indices = self.faiss_index.search(query_embedding, top_k)
        
        # FAISS pads results with -1 when top_k exceeds the number of indexed
        # vectors; a raw -1 would silently alias the LAST document via Python's
        # negative indexing, so filter those out.
        return [self.documents[idx] for idx in indices[0] if idx >= 0]
    
    def format_retrieved_context(self, docs: List[SchoolDocument]) -> str:
        """
        Format retrieved documents into a context string for the model.
        
        Args:
            docs: List of retrieved school documents
            
        Returns:
            Formatted context string headed by "# RETRIEVED_SCHOOLS"
        """
        # Build pieces in a list and join once rather than repeated +=
        parts = ["# RETRIEVED_SCHOOLS\n"]
        for i, doc in enumerate(docs, 1):
            parts.append(f"{i}. {doc.content}\n")
            
            # Add contact details from metadata when present
            if doc.metadata.get("phone"):
                parts.append(f"   Phone: {doc.metadata['phone']}\n")
            if doc.metadata.get("email"):
                parts.append(f"   Email: {doc.metadata['email']}\n")
            
        return "".join(parts)