# [file name]: core/nodes/retrieval_nodes.py
"""Retrieval nodes that route the latest user query to a country-specific retriever."""

# Path bootstrap: this must run before the project-local imports below.
# NOTE(review): Path(__file__).parent.parent resolves to the ``core/`` directory
# (core/nodes/retrieval_nodes.py -> core/nodes -> core), not its parent, yet this
# module imports ``core.retriever``. That import therefore still relies on the
# project root already being on sys.path (e.g. the working directory). Confirm
# which directory this insert is meant to expose.
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent))

import logging  # noqa: E402
from typing import Any, Dict  # noqa: E402

from langchain_core.runnables import RunnableConfig  # noqa: E402

from models.state_models import MultiCountryLegalState  # noqa: E402
from core.retriever import LegalRetriever  # noqa: E402

logger = logging.getLogger(__name__)

# Message role/type values that identify a message written by the end user.
_HUMAN_ROLES = frozenset({"user", "human"})


class RetrievalNodes:
    """Scalable legal retrieval nodes for any number of countries.

    Each configured country code maps to its own ``LegalRetriever``. A single
    generic node method serves every country.
    """

    def __init__(self, country_retrievers: Dict[str, LegalRetriever]):
        """Store the per-country retrievers.

        Args:
            country_retrievers: Mapping of country code to the retriever used
                for that country.
        """
        self.country_retrievers = country_retrievers

    async def country_retrieval_node(
        self,
        state: MultiCountryLegalState,
        config: RunnableConfig,
        country_code: str,
    ) -> Dict[str, Any]:
        """Generic country retrieval for any country.

        Takes the most recent human message from ``state`` and runs it through
        the retriever registered for ``country_code``.

        Args:
            state: Graph state; only the ``messages`` entry of its
                ``model_dump()`` is read.
            config: Runnable config passed by the caller (not used here).
            country_code: Key into ``self.country_retrievers``.

        Returns:
            A partial state update. It always contains ``search_results``,
            ``detected_articles`` and ``supplemental_message``. On success it
            also contains ``last_search_query`` and ``search_metadata``. Errors
            are reported through these fields rather than raised.
        """
        try:
            if country_code not in self.country_retrievers:
                logger.error("❌ Country not configured: %s", country_code)
                return {
                    "search_results": f"Country {country_code} not available",
                    "detected_articles": [],
                    "supplemental_message": f"Pays {country_code} non configuré dans le système.",
                }

            retriever = self.country_retrievers[country_code]

            s = state.model_dump()
            # An explicit ``messages=None`` is treated as "no messages" instead of crashing.
            last_human = self._get_last_human_message(s.get("messages") or [])
            if not last_human:
                return {
                    "search_results": f"No query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Aucune requête trouvée pour la recherche.",
                }

            # Content may be None or a list of parts; normalize before stripping.
            user_query = self._message_text(last_human).strip()
            if not user_query:
                return {
                    "search_results": f"Empty query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Requête vide pour la recherche.",
                }

            logger.info(
                "🌍 Performing %s retrieval for: '%s...'", country_code, user_query[:50]
            )

            (
                enhanced_docs,
                detected_articles,
                applied_filters,
                supplemental_message,
            ) = await retriever.smart_legal_query(user_query, country_code)

            search_results = retriever.format_search_results(
                user_query,
                enhanced_docs,
                detected_articles,
                applied_filters,
                country_code,
                supplemental_message,
            )

            logger.info(
                "📚 Retrieved %d documents for %s", len(enhanced_docs), country_code
            )

            return {
                "search_results": search_results,
                "detected_articles": detected_articles,
                "last_search_query": user_query,
                # Pass the supplemental message to state.
                "supplemental_message": supplemental_message,
                # Store complex data in search_metadata instead of legal_context.
                "search_metadata": {
                    "applied_filters": applied_filters,
                    "documents_count": len(enhanced_docs),
                    "supplemental_message": supplemental_message,
                },
            }
        except Exception as e:
            # Node boundary: report the failure in the state update, but keep the
            # traceback in the logs (logger.exception) instead of only str(e).
            logger.exception("Error in %s retrieval: %s", country_code, e)
            return {
                "search_results": f"Erreur lors de la recherche {country_code}: {str(e)}",
                "detected_articles": [],
                "supplemental_message": f"Erreur lors de la recherche: {str(e)}",
            }

    def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
        """Return the most recent user-authored message dict, or ``{}`` if none.

        A message counts as human when its ``role`` is ``"user"`` or
        ``"human"``. If ``role`` is missing, its ``type`` value is checked
        instead. Non-dict entries are skipped.
        """
        # NOTE(review): the ``type`` fallback assumes some messages are serialized
        # with a ``type`` key instead of ``role`` (presumably LangChain message
        # objects dumped by ``model_dump``). Confirm against MultiCountryLegalState.
        for msg in reversed(messages):
            if not isinstance(msg, dict):
                continue
            role = msg.get("role") or msg.get("type")
            if role in _HUMAN_ROLES:
                return msg
        return {}

    @staticmethod
    def _message_text(message: Dict[str, Any]) -> str:
        """Extract plain text from a message dict's ``content``.

        Missing or ``None`` content yields ``""``. A list of parts is flattened
        by joining its string parts and the string ``"text"`` values of its dict
        parts with spaces. Any other non-string value is converted with ``str``.
        """
        content = message.get("content")
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    texts.append(part["text"])
            return " ".join(texts)
        return str(content)