Spaces:
Sleeping
Sleeping
File size: 3,916 Bytes
fbdfc24 3e14b58 fbdfc24 478b91f fbdfc24 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 | # [file name]: core/nodes/retrieval_nodes.py
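# Prepend this file's grandparent directory to sys.path, presumably so the project imports below resolve.
# NOTE(review): Path(__file__).parent.parent is the directory that contains `nodes/` (`core/` for
# core/nodes/retrieval_nodes.py), yet the imports below use `core.` / `models.` prefixes — confirm they
# do not actually depend on the project root already being on sys.path (e.g. via the working directory).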
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import logging
from typing import Dict, Any
from langchain_core.runnables import RunnableConfig
from models.state_models import MultiCountryLegalState
from core.retriever import LegalRetriever
# Logger for this module; records are tagged with the module's import name.
logger = logging.getLogger(name=__name__)
class RetrievalNodes:
    """Legal-retrieval graph nodes that dispatch to a per-country retriever.

    One ``LegalRetriever`` is registered per country code, and a single node
    implementation serves every configured country. Supporting a new country
    only requires adding an entry to ``country_retrievers``.
    """

    def __init__(self, country_retrievers: Dict[str, LegalRetriever]):
        """Store the country-code -> retriever mapping.

        Args:
            country_retrievers: Retrievers keyed by country code. The mapping
                is stored by reference, not copied.
        """
        self.country_retrievers = country_retrievers

    async def country_retrieval_node(
        self,
        state: MultiCountryLegalState,
        config: RunnableConfig,
        country_code: str,
    ) -> Dict[str, Any]:
        """Run legal retrieval for ``country_code`` on the latest user message.

        Args:
            state: Graph state. Only the ``messages`` entry of its
                ``model_dump()`` is read.
            config: Runnable config passed in by the caller. Not used here.
            country_code: Key into ``self.country_retrievers``.

        Returns:
            A partial state update.

            On success it contains ``search_results``, ``detected_articles``,
            ``last_search_query``, ``supplemental_message`` and
            ``search_metadata``.

            In these cases it contains only ``search_results``, an empty
            ``detected_articles`` list and a French ``supplemental_message``:
            the country is not configured, no non-empty user query is found,
            or retrieval raises.

            Exceptions are logged and turned into such an update instead of
            being propagated.
        """
        try:
            retriever = self.country_retrievers.get(country_code)
            if retriever is None:
                logger.error("❌ Country not configured: %s", country_code)
                return self._failure_result(
                    f"Country {country_code} not available",
                    f"Pays {country_code} non configuré dans le système.",
                )

            snapshot = state.model_dump()
            # A missing or None "messages" entry is treated as "no query".
            last_human = self._get_last_human_message(snapshot.get("messages") or [])
            if not last_human:
                return self._failure_result(
                    f"No query for {country_code} retrieval",
                    "Aucune requête trouvée pour la recherche.",
                )

            user_query = self._message_text(last_human)
            if not user_query:
                return self._failure_result(
                    f"Empty query for {country_code} retrieval",
                    "Requête vide pour la recherche.",
                )

            logger.info("🌍 Performing %s retrieval for: '%s...'", country_code, user_query[:50])
            enhanced_docs, detected_articles, applied_filters, supplemental_message = (
                await retriever.smart_legal_query(user_query, country_code)
            )
            search_results = retriever.format_search_results(
                user_query, enhanced_docs, detected_articles, applied_filters, country_code, supplemental_message
            )
            logger.info("📚 Retrieved %d documents for %s", len(enhanced_docs), country_code)

            return {
                "search_results": search_results,
                "detected_articles": detected_articles,
                "last_search_query": user_query,
                "supplemental_message": supplemental_message,  # Pass the supplemental message to state
                # Store complex data in search_metadata instead of legal_context
                "search_metadata": {
                    "applied_filters": applied_filters,
                    "documents_count": len(enhanced_docs),
                    "supplemental_message": supplemental_message,
                },
            }
        except Exception as e:
            # Node boundary: keep the full traceback in the logs, then degrade
            # to an error update instead of propagating the exception.
            logger.exception("Error in %s retrieval: %s", country_code, e)
            return self._failure_result(
                f"Erreur lors de la recherche {country_code}: {str(e)}",
                f"Erreur lors de la recherche: {str(e)}",
            )

    @staticmethod
    def _failure_result(search_results: str, supplemental_message: str) -> Dict[str, Any]:
        """Build the partial state update shared by every non-success outcome."""
        return {
            "search_results": search_results,
            "detected_articles": [],
            "supplemental_message": supplemental_message,
        }

    def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
        """Return the most recent user message in ``messages``, or ``{}``.

        A dict entry counts as a user message when its speaker is ``"user"``
        or ``"human"``. The speaker is read from the ``"role"`` key. When that
        key is missing or empty, the ``"type"`` key is used instead; this is
        where LangChain message objects keep the speaker after
        ``model_dump()``. Entries that are not dicts are skipped.
        """
        for msg in reversed(messages):
            if not isinstance(msg, dict):
                continue
            speaker = msg.get("role") or msg.get("type")
            if speaker in ("user", "human"):
                return msg
        return {}

    @staticmethod
    def _message_text(message: Dict[str, Any]) -> str:
        """Return the stripped text of ``message["content"]``.

        String content is stripped. List content (a mix of plain strings and
        ``{"text": ...}`` parts) has its text pieces joined with single spaces,
        then stripped. Missing, ``None`` or any other content yields ``""``.
        """
        content = message.get("content")
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return " ".join(pieces).strip()
        return ""