# Spaces status: Runtime error
# File size: 3,755 Bytes
# api/database.py
import faiss
import numpy as np
import gridfs
import re
from pymongo import MongoClient
from sentence_transformers import SentenceTransformer
from .config import mongo_uri, index_uri, MODEL_CACHE_DIR, EMBEDDING_MODEL_DEVICE
import logging
# Module-wide logger, registered under the fixed name "database-bot"
# (not derived from the module's __name__).
logger = logging.getLogger("database-bot")
class DatabaseManager:
    """Hold the embedding model, FAISS index and MongoDB handles for the API.

    Constructing an instance performs no I/O: every attribute starts as
    ``None`` and is populated by the ``initialize_*`` / ``load_*`` /
    ``get_*`` methods the first time they are needed.
    """

    def __init__(self):
        """Create an empty manager; nothing is connected or loaded yet."""
        # Embedding model and in-memory search structures
        self.embedding_model = None
        self.index = None
        self.symptom_vectors = None
        self.symptom_docs = None
        # MongoDB connections
        self.client = None
        self.iclient = None
        self.symptom_client = None
        # Collections
        self.qa_collection = None
        self.index_collection = None
        self.symptom_col = None
        self.fs = None

    def initialize_embedding_model(self):
        """Load the SentenceTransformer model from ``MODEL_CACHE_DIR``.

        The model is placed on ``EMBEDDING_MODEL_DEVICE`` and converted
        with ``.half()``.

        Raises:
            Exception: whatever the model loading raised; it is logged
                and then re-raised unchanged.
        """
        logger.info("[Embedder] 📥 Loading SentenceTransformer Model...")
        try:
            self.embedding_model = SentenceTransformer(MODEL_CACHE_DIR, device=EMBEDDING_MODEL_DEVICE)
            # NOTE(review): half precision is used to reduce memory; confirm
            # it is supported on the configured device.
            self.embedding_model = self.embedding_model.half()  # Reduce memory
            logger.info("✅ Model Loaded Successfully.")
        except Exception as e:
            logger.error(f"❌ Model Loading Failed: {e}")
            raise

    def initialize_mongodb(self):
        """Create the MongoDB clients and bind every collection handle.

        Sets ``client``/``qa_collection`` (QA data), ``iclient``/
        ``index_collection``/``fs`` (FAISS index storage, via GridFS) and
        ``symptom_client``/``symptom_col`` (symptom diagnosis data).
        """
        # QA data
        self.client = MongoClient(mongo_uri)
        db = self.client["MedicalChatbotDB"]
        self.qa_collection = db["qa_data"]
        # FAISS Index data
        self.iclient = MongoClient(index_uri)
        idb = self.iclient["MedicalChatbotDB"]
        self.index_collection = idb["faiss_index_files"]
        # Symptom Diagnosis data
        # NOTE(review): this opens a second client on the same mongo_uri as
        # self.client; it could probably be shared — confirm before merging.
        self.symptom_client = MongoClient(mongo_uri)
        self.symptom_col = self.symptom_client["MedicalChatbotDB"]["symptom_diagnosis"]
        # GridFS for FAISS index
        self.fs = gridfs.GridFS(idb, collection="faiss_index_files")

    def load_faiss_index(self):
        """Lazily load the FAISS index stored in GridFS.

        If the index is already in memory it is returned as-is. Otherwise
        the file ``faiss_index.bin`` is looked up in GridFS, read and
        deserialized. MongoDB is initialized on demand when the GridFS
        handle has not been created yet.

        Returns:
            The loaded index, or ``None`` when the file is not present in
            GridFS (an error is logged and a later call will retry).
        """
        if self.index is not None:
            return self.index

        # Previously this dereferenced self.fs while it could still be None
        # (AttributeError) when called before initialize_mongodb().
        if self.fs is None:
            self.initialize_mongodb()

        logger.info("[KB] ⏳ Loading FAISS index from GridFS...")
        existing_file = self.fs.find_one({"filename": "faiss_index.bin"})
        if existing_file:
            stored_index_bytes = existing_file.read()
            # faiss expects the serialized index as a uint8 array
            index_bytes_np = np.frombuffer(stored_index_bytes, dtype='uint8')
            self.index = faiss.deserialize_index(index_bytes_np)
            logger.info("[KB] ✅ FAISS Index Loaded")
        else:
            logger.error("[KB] ❌ FAISS index not found in GridFS.")
        return self.index

    def load_symptom_vectors(self):
        """Lazily load the symptom documents and their embedding matrix.

        On the first call, every document of the symptom collection is
        fetched (projecting ``embedding``, ``answer``, ``question`` and
        ``prognosis``), kept in ``symptom_docs``, and the ``embedding``
        fields are stacked into a ``float32`` array in ``symptom_vectors``.
        Later calls are no-ops. MongoDB is initialized on demand when the
        collection handle has not been created yet.

        Raises:
            KeyError: if a fetched document has no ``embedding`` field.
        """
        if self.symptom_vectors is not None:
            return

        # Previously this dereferenced self.symptom_col while it could still
        # be None (AttributeError) when called before initialize_mongodb().
        if self.symptom_col is None:
            self.initialize_mongodb()

        all_docs = list(self.symptom_col.find({}, {"embedding": 1, "answer": 1, "question": 1, "prognosis": 1}))
        self.symptom_docs = all_docs
        self.symptom_vectors = np.array([doc["embedding"] for doc in all_docs], dtype=np.float32)

    def get_embedding_model(self):
        """Return the embedding model, loading it on first use."""
        if self.embedding_model is None:
            self.initialize_embedding_model()
        return self.embedding_model

    def get_qa_collection(self):
        """Return the QA collection, initializing MongoDB on first use."""
        if self.qa_collection is None:
            self.initialize_mongodb()
        return self.qa_collection

    def get_symptom_collection(self):
        """Return the symptom collection, initializing MongoDB on first use."""
        if self.symptom_col is None:
            self.initialize_mongodb()
        return self.symptom_col
# Global database manager instance shared by importers of this module.
# Constructing it only sets attributes to None; no connection is opened
# and no model is loaded until the manager's methods are called.
db_manager = DatabaseManager()
|