# MultiCountryRAG / core/system_initializer.py
# SAAHMATHWORKS
# ready for hugging face space
# f37bf1d
# [file name]: core/system_initializer.py
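# Put the directory containing core/ first on sys.path so the absolute imports below (core, database, config) resolve regardless of the current working directory.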
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import logging
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
from langgraph.checkpoint.memory import InMemorySaver
from core.graph_builder import GraphBuilder
from core.chat_manager import LegalChatManager
from core.router import CountryRouter
from database.mongodb_client import MongoDBClient
from database.postgres_checkpointer import PostgresCheckpointer
from langchain_openai import ChatOpenAI
from config import settings
# Module-level logger named after this module; shared by the functions below.
logger = logging.getLogger(__name__)
async def setup_system():
    """Initialize the legal assistant system with fallback to in-memory checkpointer.

    Wires together, in order: the MongoDB client, one ``LegalRetriever`` per
    country ("benin", "madagascar"), a streaming ``ChatOpenAI`` LLM, the
    ``CountryRouter``, a checkpointer (PostgreSQL when configured and
    initialized successfully, otherwise in-memory), the compiled LangGraph
    app, and the ``LegalChatManager`` that wraps it.

    Returns:
        dict: ``{"chat_manager": LegalChatManager, "graph": <compiled app>,
        "checkpointer": <checkpointer>}``.

    Raises:
        RuntimeError: if ``MongoDBClient.connect()`` reports failure.
        Exception: any other error raised during initialization is logged
            together with its traceback and re-raised unchanged.
    """
    try:
        # 1. Connect to MongoDB; connect() signals failure via a falsy return.
        mongo_client = MongoDBClient()
        if not mongo_client.connect():
            # RuntimeError instead of a bare Exception: more specific, and
            # still caught by callers that use `except Exception`.
            raise RuntimeError("MongoDB connection failed")
        logger.info("βœ… MongoDB connected successfully")

        # 2. Per-country vector stores and collections exposed by the client.
        vector_store_benin = mongo_client.benin_vectorstore
        collection_benin = mongo_client.benin_collection
        vector_store_madagascar = mongo_client.madagascar_vectorstore
        collection_madagascar = mongo_client.madagascar_collection

        # 3. One retriever per country, keyed by country name.
        # Local import -- presumably to avoid an import cycle between core
        # modules; confirm before hoisting it to module level.
        from core.retriever import LegalRetriever

        country_retrievers = {
            "benin": LegalRetriever(vector_store_benin, collection_benin),
            "madagascar": LegalRetriever(vector_store_madagascar, collection_madagascar),
        }

        # 4. LLM (low temperature, token streaming enabled) and country router.
        llm = ChatOpenAI(
            model="gpt-4o-mini",
            temperature=0.1,
            max_tokens=2000,
            streaming=True,
        )
        router = CountryRouter()

        # 5. PostgreSQL checkpointer when available, otherwise in-memory.
        checkpointer = await _initialize_checkpointer_with_fallback()

        # 6. Build the workflow graph and compile it with the same checkpointer.
        graph_builder = GraphBuilder(
            router=router,
            llm=llm,
            checkpointer=checkpointer,
            country_retrievers=country_retrievers,
        )
        workflow = graph_builder.build_graph()
        app = workflow.compile(checkpointer=checkpointer)

        # 7. Chat manager on top of the compiled app.
        chat_manager = LegalChatManager(app, checkpointer)
        logger.info("βœ… API System initialized successfully")

        return {
            "chat_manager": chat_manager,
            "graph": app,
            "checkpointer": checkpointer,
        }
    except Exception as e:
        # logger.exception (unlike logger.error) records the traceback, so the
        # root cause stays visible in the logs even if a caller swallows the
        # re-raised exception.
        logger.exception(f"❌ Failed to initialize system: {e}")
        raise
async def _initialize_checkpointer_with_fallback():
    """Initialize checkpointer with fallback to in-memory if PostgreSQL fails.

    The database URL is taken from the first truthy setting among
    ``DATABASE_URL``, ``POSTGRES_URL``, ``POSTGRESQL_URL`` and ``DB_URL``.
    When one is found, a ``PostgresCheckpointer`` (min_connections=2,
    max_connections=10) is created and awaited via ``initialize()``. If that
    returns a falsy value or raises, or if no URL is configured, an
    ``InMemorySaver`` is used instead.

    Returns:
        The checkpointer from ``PostgresCheckpointer.get_checkpointer()`` on
        success, otherwise a new ``InMemorySaver``.

    Raises:
        RuntimeError: if even the ``InMemorySaver`` cannot be constructed;
            the underlying error is chained as ``__cause__``.
    """
    database_url = _resolve_database_url()

    if database_url:
        try:
            # Log only host/port/database: a connection URL can embed
            # credentials in its userinfo part or as query parameters
            # (e.g. ?password=...), which must never reach the logs.
            logger.info(
                f"πŸ”— Attempting PostgreSQL connection: {_describe_database_url(database_url)}"
            )
            postgres_checkpointer = PostgresCheckpointer(
                database_url=database_url,
                max_connections=10,
                min_connections=2,
            )
            if await postgres_checkpointer.initialize():
                checkpointer = postgres_checkpointer.get_checkpointer()
                logger.info("βœ… PostgreSQL checkpointer initialized successfully")
                return checkpointer
            # NOTE(review): on this path (and in the except below) the
            # PostgresCheckpointer is discarded without any cleanup call; if it
            # keeps a partially opened connection pool, confirm whether it
            # needs an explicit close here.
            logger.warning("❌ PostgreSQL checkpointer initialization failed, will fall back to in-memory")
        except Exception as e:
            # Deliberately broad: any PostgreSQL problem degrades to the
            # in-memory checkpointer instead of aborting startup.
            logger.warning(f"❌ PostgreSQL connection failed: {e}, falling back to in-memory checkpointer")
    else:
        logger.warning("❌ No database URL found in settings, using in-memory checkpointer")

    # Fall back to the in-memory checkpointer (keep the try to the one risky call).
    try:
        checkpointer = InMemorySaver()
    except Exception as e:
        logger.error(f"❌ Even in-memory checkpointer failed: {e}")
        raise RuntimeError("Failed to initialize any checkpointer") from e

    logger.info("βœ… In-memory checkpointer initialized as fallback")
    logger.warning("⚠️ Using in-memory checkpointer - conversation history will not persist across restarts")
    return checkpointer


def _resolve_database_url():
    """Return the first truthy database URL found in ``settings``, else ``None``.

    ``DATABASE_URL`` is preferred; ``POSTGRES_URL``, ``POSTGRESQL_URL`` and
    ``DB_URL`` are accepted as alternative setting names, in that order.
    """
    for setting_name in ("DATABASE_URL", "POSTGRES_URL", "POSTGRESQL_URL", "DB_URL"):
        value = getattr(settings, setting_name, None)
        if value:
            return value
    return None


def _describe_database_url(database_url):
    """Return a log-safe ``host[:port][/path]`` summary of *database_url*.

    Username, password, query string and fragment are never included.
    Returns ``"local"`` when the value has no authority part (e.g.
    ``postgresql:///db`` or a ``key=value`` DSN, which may itself contain
    ``password=...``). Uses ``"local"`` as the host when the authority has
    no hostname. Returns ``"<unparseable URL>"`` when the URL cannot be
    parsed, such as when it has a non-numeric port.
    """
    from urllib.parse import urlsplit

    try:
        parts = urlsplit(str(database_url))
        if not parts.netloc:
            # Not an authority-style URL: reveal nothing about it.
            return "local"
        host = parts.hostname or "local"
        port = parts.port  # raises ValueError for an invalid port
    except ValueError:
        return "<unparseable URL>"

    location = f"{host}:{port}" if port is not None else host
    return f"{location}{parts.path}"
def get_checkpointer_type(checkpointer):
    """Classify *checkpointer* by the name of its class.

    Returns:
        ``"postgres"`` if the class name contains ``"PostgresSaver"``
        (checked first), ``"memory"`` if it contains ``"InMemorySaver"``,
        and ``"unknown"`` otherwise.
    """
    if not hasattr(checkpointer, '__class__'):
        return "unknown"

    class_name = checkpointer.__class__.__name__
    # Ordered (substring, label) pairs; the first match wins.
    name_markers = (
        ("PostgresSaver", "postgres"),
        ("InMemorySaver", "memory"),
    )
    return next(
        (label for marker, label in name_markers if marker in class_name),
        "unknown",
    )