SQuAD / utils /db.py
tnp554's picture
feat: deploy SQuAD backend with all AI models
09daf0b
"""
db.py — MongoDB Atlas connection with mongomock fallback.
If MONGO_URI is not set or the connection fails, the app runs on an
in-memory mock store so development works without any database.
"""
import os
import logging
from dotenv import load_dotenv
load_dotenv()
logger = logging.getLogger(__name__)

# Connection string: MONGO_URI wins, MONGODB_URI is accepted as an alias.
# Surrounding whitespace (e.g. a trailing newline pasted into a secret) is
# stripped so it cannot make an otherwise valid URI fail to parse; a
# whitespace-only MONGO_URI falls through to MONGODB_URI.
MONGO_URI = (os.getenv("MONGO_URI") or "").strip() or (os.getenv("MONGODB_URI") or "").strip()
# Database name; overridable via MONGO_DB_NAME, defaults to "squad_qa".
DB_NAME = os.getenv("MONGO_DB_NAME") or "squad_qa"

# Module-level connection state, populated by _connect_atlas() / _connect_mock().
_client = None  # pymongo or mongomock client; None until a backend is set up
_db = None  # active database handle; stays None if no backend is available
_using_mock = False  # True once the mongomock fallback is active
def _connect_atlas():
    """Attempt to connect to MongoDB Atlas (or a self-hosted MongoDB) via MONGO_URI.

    On success, sets the module-level ``_client`` and ``_db`` to the real
    database and clears ``_using_mock``. On any failure the error is logged,
    any half-open client is closed, and ``_connect_mock()`` is called
    instead. Failures include:

    - the URI is missing or still the ``username:password`` placeholder;
    - pymongo cannot be imported;
    - the URI is invalid;
    - the ``ping`` does not succeed.

    TLS is governed by the URI itself: ``mongodb+srv://`` URIs enable TLS by
    default, and ``mongodb://`` URIs can add ``tls=true``. Plain local
    servers therefore keep working.

    Certificate verification stays on. Setting the environment variable
    MONGO_TLS_ALLOW_INVALID_CERTS to 1/true/yes/on restores the previous
    "TLS without certificate checks" behaviour. That mode is open to
    man-in-the-middle attacks and should only be used as a last resort.
    """
    global _client, _db, _using_mock
    client = None
    try:
        from pymongo import MongoClient

        if not MONGO_URI or "username:password" in MONGO_URI:
            raise ValueError("MONGO_URI not configured — falling back to mock.")

        options = {"serverSelectionTimeoutMS": 5000}
        insecure = os.getenv("MONGO_TLS_ALLOW_INVALID_CERTS", "").strip().lower()
        if insecure in {"1", "true", "yes", "on"}:
            # Explicit opt-in escape hatch (e.g. a host with a broken CA bundle).
            options.update(tls=True, tlsAllowInvalidCertificates=True)
            logger.warning(
                "[DB] TLS certificate verification DISABLED via "
                "MONGO_TLS_ALLOW_INVALID_CERTS — the server is not authenticated."
            )

        client = MongoClient(MONGO_URI, **options)
        # Trigger an actual connection check now rather than on the first query.
        client.admin.command("ping")
    except Exception as exc:  # boundary: any failure means "use the mock"
        if client is not None:
            # Don't leak the background monitor threads of a failed client.
            client.close()
        logger.warning("[DB] MongoDB connection failed: %s", exc)
        logger.warning("[DB] Falling back to in-memory mongomock.")
        _connect_mock()
        return

    _client = client
    _db = client[DB_NAME]
    _using_mock = False
    logger.info("[DB] Connected to MongoDB Atlas successfully.")
def _connect_mock():
    """Switch the module over to an in-memory mongomock database.

    Nothing stored through this backend survives a process restart. When
    mongomock cannot be imported, an error is logged and ``_db`` is reset to
    None so callers can tell that no database is available.
    """
    global _client, _db, _using_mock
    try:
        import mongomock
    except ImportError:
        logger.error("[DB] mongomock not installed. Database unavailable.")
        _db = None
        return

    _client = mongomock.MongoClient()
    _db = _client[DB_NAME]
    _using_mock = True
    logger.warning("[DB] Running on mongomock — data will NOT persist across restarts.")
def get_db():
    """Return the live database handle, connecting on first use.

    The handle is the pymongo database when the real server was reachable,
    otherwise the mongomock stand-in. While no handle exists, every call
    makes a fresh connection attempt; None is returned if every backend
    failed.
    """
    handle = _db
    if handle is not None:
        return handle
    # No backend yet (or all failed last time): try again, real server first.
    _connect_atlas()
    return _db
def is_using_mock() -> bool:
    """Report whether the in-memory mongomock fallback is the active backend.

    False means either a real MongoDB connection is in use or no backend
    could be set up at all.
    """
    mock_active = _using_mock
    return mock_active
# Initialise on import: connect eagerly so a database handle (real or the
# mongomock fallback) exists before the first accessor call. With a configured
# URI this can block while the "ping" waits on server selection
# (serverSelectionTimeoutMS=5000 in _connect_atlas).
_connect_atlas()
# Convenience collection accessors.
# Each goes through get_db(), so a lazy connection attempt happens if no handle
# exists yet. If neither MongoDB nor mongomock is available, a clear
# RuntimeError is raised instead of an opaque
# "'NoneType' object is not subscriptable" TypeError.
def _collection(name):
    """Return collection *name* from the active database, or raise RuntimeError."""
    db = get_db()
    if db is None:
        raise RuntimeError(
            f"[DB] Database unavailable: cannot access collection {name!r} "
            "(MongoDB connection failed and mongomock is not installed)."
        )
    return db[name]


def users_col():
    """Return the ``users`` collection of the active database."""
    return _collection("users")


def chats_col():
    """Return the ``chats`` collection of the active database."""
    return _collection("chats")


def settings_col():
    """Return the ``settings`` collection of the active database."""
    return _collection("settings")