Spaces:
Running
Running
| # src/rag_manager.py | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
class SociolinguisticRAG:
    """In-memory retrieval store for sociolinguistic persona rules.

    Rule strings are embedded with a SentenceTransformer model and stored in
    a flat FAISS index. ``self.chunks[i]`` is always the text whose vector
    sits at row ``i`` of ``self.index``.
    """

    def __init__(self):
        """Load the embedding model and create an empty FAISS index."""
        print("🧠 Initializing Local Embedding Model (all-MiniLM-L6-v2)...")
        # The model (~80MB) is downloaded to the server on first use.
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.embedding_dim = self.encoder.get_sentence_embedding_dimension()

        # Empty in-memory index that ranks neighbours by L2 distance.
        # NOTE(review): L2 ranking only coincides with cosine-similarity
        # ranking when the embeddings are unit-normalized -- confirm for
        # the configured model before relying on that.
        self.index = faiss.IndexFlatL2(self.embedding_dim)
        self.chunks = []

    def load_persona_rules(self, dialect_name: str, rules_list: list):
        """Vectorize a dialect's rules and make them the active rule set.

        Any rules loaded by a previous call are replaced, so the index rows
        and ``self.chunks`` always line up one-to-one.

        Args:
            dialect_name: Human-readable dialect label, used only in the
                progress messages.
            rules_list: List of rule strings, e.g.
                ``["Rule 1...", "Rule 2..."]``. An empty or ``None`` value
                is a no-op and leaves the current rule set untouched.
        """
        if not rules_list:
            return

        print(f"📚 Vectorizing {len(rules_list)} rules for {dialect_name}...")

        # Work on a private copy so later mutation of the caller's list
        # cannot desynchronise the stored texts from the index rows.
        rules = list(rules_list)

        # FAISS requires contiguous float32 numpy arrays.
        embeddings = np.ascontiguousarray(
            self.encoder.encode(rules), dtype="float32"
        )

        # Drop vectors from any earlier load; otherwise row ids returned by
        # a search would no longer correspond to positions in self.chunks.
        self.index.reset()
        self.index.add(embeddings)
        # Only publish the new texts once the index has been rebuilt.
        self.chunks = rules

        print(f"✅ {dialect_name} successfully vectorized and indexed in RAM.")

    def retrieve_context(self, user_transcription: str, k=3) -> str:
        """Return the rules most relevant to a transcribed utterance.

        Args:
            user_transcription: The transcribed user sentence to match.
            k: Maximum number of rules to return. Values larger than the
                number of indexed rules are clamped.

        Returns:
            A newline-separated bullet list (``"- rule"`` per line), a
            fixed notice when no rules are loaded, or an empty string when
            ``k`` is not positive.
        """
        if not self.chunks:
            return "No specific rules loaded for this dialect."

        # Never request more neighbours than there are stored vectors.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return ""

        # 1. Convert the user's sentence into a float32 query vector.
        query_vector = np.ascontiguousarray(
            self.encoder.encode([user_transcription]), dtype="float32"
        )

        # 2. Search the index for the closest vectors.
        _distances, indices = self.index.search(query_vector, k)

        # 3. Map row ids back to rule text. FAISS pads missing results with
        #    -1, which would otherwise wrap around to the last rule.
        chunk_count = len(self.chunks)
        retrieved_rules = [
            self.chunks[i] for i in indices[0] if 0 <= i < chunk_count
        ]

        # 4. Format the rules as a bullet list for the prompt.
        return "\n".join(f"- {rule}" for rule in retrieved_rules)
# ------------------------------------------
# Example usage / local smoke test
# ------------------------------------------
if __name__ == "__main__":
    # Startup phase: build the store, then index the sample dialect rules.
    rag_store = SociolinguisticRAG()
    sample_rules = [
        "Lexicon: 'Wahala' means trouble or problem.",
        "Lexicon: 'How far' is a greeting meaning 'How are you' or 'What's going on'.",
        "Pragmatics: Repeating a word (e.g., 'now now') emphasizes extreme urgency.",
        "Syntax: Pluralization is often done by adding 'dem' after a noun.",
    ]
    rag_store.load_persona_rules("Nigerian Pidgin", sample_rules)

    # Request phase: look up the two rules closest to the spoken sentence.
    whisper_text = "What's going on now?"
    print(f"\n🗣️ User said: {whisper_text}")

    relevant_context = rag_store.retrieve_context(whisper_text, k=2)
    print(f"🎯 Retrieved Context for Gemini:\n{relevant_context}")