# MultiCountryRAG / core/nodes/retrieval_nodes.py
# Author: SAAHMATHWORKS
# Commit 478b91f ("dockerfile 3")
# [file name]: core/nodes/retrieval_nodes.py
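# Prepend this file's grandparent directory (core/ when this file lives at
# core/nodes/retrieval_nodes.py) to sys.path before the project imports below.
# NOTE(review): `from core.retriever import ...` additionally needs the project
# root on sys.path — confirm how this module is launched.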
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import logging
from typing import Dict, Any
from langchain_core.runnables import RunnableConfig
from models.state_models import MultiCountryLegalState
from core.retriever import LegalRetriever
logger = logging.getLogger(__name__)
class RetrievalNodes:
    """Scalable legal retrieval nodes for any number of countries.

    Holds one ``LegalRetriever`` per country code and exposes a single generic
    async node, ``country_retrieval_node``, that runs retrieval for the
    requested country against the most recent user/human message in the state.
    """

    # Message roles treated as coming from the end user.
    _HUMAN_ROLES = ("user", "human")

    def __init__(self, country_retrievers: Dict[str, LegalRetriever]):
        """Store the per-country retrievers.

        Args:
            country_retrievers: Mapping from country code to the
                ``LegalRetriever`` that serves that country.
        """
        self.country_retrievers = country_retrievers

    async def country_retrieval_node(
        self,
        state: MultiCountryLegalState,
        config: RunnableConfig,
        country_code: str,
    ) -> Dict[str, Any]:
        """Run legal retrieval for ``country_code`` on the last human message.

        Args:
            state: Current graph state. It is dumped with ``model_dump()`` and
                its ``messages`` are scanned for the most recent user/human one.
            config: Runnable config; not used by this node.
            country_code: Key into ``self.country_retrievers``.

        Returns:
            A partial state update. It always contains ``search_results``,
            ``detected_articles`` and ``supplemental_message``. On a successful
            search it also contains ``last_search_query`` and ``search_metadata``
            (applied filters, document count, supplemental message).
            Failures (unknown country, no/empty query, retriever errors) are
            reported through these keys instead of being raised.
        """
        try:
            if country_code not in self.country_retrievers:
                logger.error(f"❌ Country not configured: {country_code}")
                return {
                    "search_results": f"Country {country_code} not available",
                    "detected_articles": [],
                    "supplemental_message": f"Pays {country_code} non configuré dans le système.",
                }
            retriever = self.country_retrievers[country_code]

            s = state.model_dump()
            # `or []` also covers a state whose `messages` is explicitly None.
            last_human = self._get_last_human_message(s.get("messages") or [])
            if not last_human:
                return {
                    "search_results": f"No query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Aucune requête trouvée pour la recherche.",
                }

            # Content may be missing/None or a list of content parts, not only a
            # plain string; normalize to text before stripping.
            user_query = self._message_text(last_human.get("content")).strip()
            if not user_query:
                return {
                    "search_results": f"Empty query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Requête vide pour la recherche.",
                }

            logger.info(f"🌍 Performing {country_code} retrieval for: '{user_query[:50]}...'")
            (
                enhanced_docs,
                detected_articles,
                applied_filters,
                supplemental_message,
            ) = await retriever.smart_legal_query(user_query, country_code)

            search_results = retriever.format_search_results(
                user_query,
                enhanced_docs,
                detected_articles,
                applied_filters,
                country_code,
                supplemental_message,
            )
            logger.info(f"📚 Retrieved {len(enhanced_docs)} documents for {country_code}")

            return {
                "search_results": search_results,
                "detected_articles": detected_articles,
                "last_search_query": user_query,
                "supplemental_message": supplemental_message,  # Pass the supplemental message to state
                # Store complex data in search_metadata instead of legal_context
                "search_metadata": {
                    "applied_filters": applied_filters,
                    "documents_count": len(enhanced_docs),
                    "supplemental_message": supplemental_message,
                },
            }
        except Exception as e:
            # Node boundary: report the failure through the state instead of
            # crashing the graph, but keep the full traceback in the logs.
            logger.exception(f"Error in {country_code} retrieval: {str(e)}")
            return {
                "search_results": f"Erreur lors de la recherche {country_code}: {str(e)}",
                "detected_articles": [],
                "supplemental_message": f"Erreur lors de la recherche: {str(e)}",
            }

    @staticmethod
    def _message_text(content: Any) -> str:
        """Return the textual part of a message ``content`` value.

        A string is returned unchanged. A list of content parts (plain strings,
        or dicts carrying a string ``"text"`` entry) is joined with single
        spaces. Anything else, including ``None``, yields ``""``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            return " ".join(texts)
        return ""

    def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
        """Return the most recent user/human message dict, or ``{}`` if none.

        The role is read from ``"role"``, falling back to ``"type"`` (the field
        carried by serialized LangChain messages, e.g. ``"human"``). Entries
        that are not dicts are skipped rather than raising.
        """
        for msg in reversed(messages):
            if not isinstance(msg, dict):
                continue
            role = msg.get("role") or msg.get("type")
            if role in self._HUMAN_ROLES:
                return msg
        return {}