File size: 3,916 Bytes
fbdfc24
3e14b58
 
 
 
 
fbdfc24
 
 
 
478b91f
 
fbdfc24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# [file name]: core/nodes/retrieval_nodes.py
# Prepend the parent of this file's directory to sys.path, presumably so the
# absolute imports below resolve when this module is loaded outside an
# installed package. This must run before those imports.
# NOTE(review): for core/nodes/retrieval_nodes.py, Path(__file__).parent.parent
# is the `core/` directory, yet `core.retriever` is imported below as if the
# project root were on sys.path; confirm which root these imports expect.
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))

import logging
from typing import Dict, Any
from langchain_core.runnables import RunnableConfig

from models.state_models import MultiCountryLegalState
from core.retriever import LegalRetriever

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

class RetrievalNodes:
    """Legal retrieval graph nodes, parameterised by country.

    One ``LegalRetriever`` is registered per country code. The single
    ``country_retrieval_node`` coroutine serves every configured country, so
    supporting another country only requires adding an entry to the mapping.
    """

    def __init__(self, country_retrievers: Dict[str, LegalRetriever]):
        # Mapping of country code -> retriever used for that country's queries.
        self.country_retrievers = country_retrievers

    async def country_retrieval_node(self, state: MultiCountryLegalState, config: RunnableConfig, country_code: str) -> Dict[str, Any]:
        """Run legal retrieval for ``country_code`` on the latest human message.

        Args:
            state: Graph state. Only its ``messages`` entry, read through
                ``model_dump()``, is used.
            config: Runnable config. Not read here; presumably accepted to
                match the graph's node-call signature.
            country_code: Key into ``self.country_retrievers``.

        Returns:
            A partial state update. Every path sets ``search_results``,
            ``detected_articles`` and ``supplemental_message``. A successful
            search also sets ``last_search_query`` and ``search_metadata``.
            Failures (unknown country, missing or empty query, or any
            ``Exception`` raised while retrieving) are reported through these
            fields instead of being raised.
        """
        try:
            if country_code not in self.country_retrievers:
                logger.error("❌ Country not configured: %s", country_code)
                return {
                    "search_results": f"Country {country_code} not available",
                    "detected_articles": [],
                    "supplemental_message": f"Pays {country_code} non configuré dans le système.",
                }

            retriever = self.country_retrievers[country_code]
            # `or []` also covers a state whose `messages` field is present but None.
            messages = state.model_dump().get("messages") or []
            last_human = self._get_last_human_message(messages)

            if not last_human:
                return {
                    "search_results": f"No query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Aucune requête trouvée pour la recherche.",
                }

            # Content may be a plain string, None, or a list of content blocks;
            # calling .strip() on the raw value would crash on the latter two.
            user_query = self._message_text(last_human.get("content")).strip()
            if not user_query:
                return {
                    "search_results": f"Empty query for {country_code} retrieval",
                    "detected_articles": [],
                    "supplemental_message": "Requête vide pour la recherche.",
                }

            logger.info("🌍 Performing %s retrieval for: '%s...'", country_code, user_query[:50])

            enhanced_docs, detected_articles, applied_filters, supplemental_message = await retriever.smart_legal_query(user_query, country_code)

            search_results = retriever.format_search_results(
                user_query, enhanced_docs, detected_articles, applied_filters, country_code, supplemental_message
            )

            logger.info("📚 Retrieved %d documents for %s", len(enhanced_docs), country_code)

            return {
                "search_results": search_results,
                "detected_articles": detected_articles,
                "last_search_query": user_query,
                # Surfaced at top level so later nodes can read it directly.
                "supplemental_message": supplemental_message,
                # Structured details go in search_metadata rather than legal_context.
                "search_metadata": {
                    "applied_filters": applied_filters,
                    "documents_count": len(enhanced_docs),
                    "supplemental_message": supplemental_message,
                },
            }

        except Exception as e:
            # Node boundary: keep the traceback in the logs, then report the
            # failure through the state update instead of aborting the run.
            logger.exception("Error in %s retrieval: %s", country_code, e)
            return {
                "search_results": f"Erreur lors de la recherche {country_code}: {str(e)}",
                "detected_articles": [],
                "supplemental_message": f"Erreur lors de la recherche: {str(e)}",
            }

    def _get_last_human_message(self, messages: list) -> Dict[str, Any]:
        """Return the most recent human message in ``messages``, or ``{}``.

        A message counts as human when its ``role`` is ``"user"`` or
        ``"human"``. Dicts without a ``role`` fall back to ``type``, which is
        where dumped LangChain messages store the speaker (e.g. ``"human"``).
        """
        for msg in reversed(messages):
            if (msg.get("role") or msg.get("type")) in ("user", "human"):
                return msg
        return {}

    @staticmethod
    def _message_text(content: Any) -> str:
        """Flatten message ``content`` into plain text.

        Strings are returned unchanged and ``None`` becomes ``""``. For a list
        of content blocks, bare strings and the ``text`` of blocks whose
        ``type`` is ``"text"`` are kept and joined with single spaces. Other
        blocks, such as images, are dropped. Any other value goes through
        ``str()``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif (
                    isinstance(block, dict)
                    and block.get("type") == "text"
                    and isinstance(block.get("text"), str)
                ):
                    parts.append(block["text"])
            return " ".join(parts)
        return str(content)