Spaces:
Runtime error
Runtime error
| # api/database.py | |
| import faiss | |
| import numpy as np | |
| import gridfs | |
| import re | |
| from pymongo import MongoClient | |
| from sentence_transformers import SentenceTransformer | |
| from .config import mongo_uri, index_uri, MODEL_CACHE_DIR, EMBEDDING_MODEL_DEVICE | |
| import logging | |
| logger = logging.getLogger("database-bot") | |
| class DatabaseManager: | |
| def __init__(self): | |
| self.embedding_model = None | |
| self.index = None | |
| self.symptom_vectors = None | |
| self.symptom_docs = None | |
| # MongoDB connections | |
| self.client = None | |
| self.iclient = None | |
| self.symptom_client = None | |
| # Collections | |
| self.qa_collection = None | |
| self.index_collection = None | |
| self.symptom_col = None | |
| self.fs = None | |
| def initialize_embedding_model(self): | |
| """Initialize the SentenceTransformer model""" | |
| logger.info("[Embedder] 📥 Loading SentenceTransformer Model...") | |
| try: | |
| self.embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device=EMBEDDING_MODEL_DEVICE) | |
| self.embedding_model = self.embedding_model.half() # Reduce memory | |
| logger.info("✅ Model Loaded Successfully.") | |
| except Exception as e: | |
| logger.error(f"❌ Model Loading Failed: {e}") | |
| raise | |
| def initialize_mongodb(self): | |
| """Initialize MongoDB connections and collections""" | |
| # QA data | |
| self.client = MongoClient(mongo_uri) | |
| db = self.client["MedicalChatbotDB"] | |
| self.qa_collection = db["qa_data"] | |
| # FAISS Index data | |
| self.iclient = MongoClient(index_uri) | |
| idb = self.iclient["MedicalChatbotDB"] | |
| self.index_collection = idb["faiss_index_files"] | |
| # Symptom Diagnosis data | |
| self.symptom_client = MongoClient(mongo_uri) | |
| self.symptom_col = self.symptom_client["MedicalChatbotDB"]["symptom_diagnosis"] | |
| # GridFS for FAISS index | |
| self.fs = gridfs.GridFS(idb, collection="faiss_index_files") | |
| def load_faiss_index(self): | |
| """Lazy load FAISS index from GridFS""" | |
| if self.index is None: | |
| logger.info("[KB] ⏳ Loading FAISS index from GridFS...") | |
| existing_file = self.fs.find_one({"filename": "faiss_index.bin"}) | |
| if existing_file: | |
| stored_index_bytes = existing_file.read() | |
| index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8') | |
| self.index = faiss.deserialize_index(index_bytes_np) | |
| logger.info("[KB] ✅ FAISS Index Loaded") | |
| else: | |
| logger.error("[KB] ❌ FAISS index not found in GridFS.") | |
| return self.index | |
| def load_symptom_vectors(self): | |
| """Lazy load symptom vectors for diagnosis""" | |
| if self.symptom_vectors is None: | |
| all_docs = list(self.symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1})) | |
| self.symptom_docs = all_docs | |
| self.symptom_vectors = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32) | |
| def get_embedding_model(self): | |
| """Get the embedding model""" | |
| if self.embedding_model is None: | |
| self.initialize_embedding_model() | |
| return self.embedding_model | |
| def get_qa_collection(self): | |
| """Get QA collection""" | |
| if self.qa_collection is None: | |
| self.initialize_mongodb() | |
| return self.qa_collection | |
| def get_symptom_collection(self): | |
| """Get symptom collection""" | |
| if self.symptom_col is None: | |
| self.initialize_mongodb() | |
| return self.symptom_col | |
| # Global database manager instance | |
| db_manager = DatabaseManager() | |