File size: 6,443 Bytes
3998131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""
RAG Chain Module
Orchestrates retrieval and generation for legal explanations
"""

import logging
from typing import Dict, Any, List, Optional

from .embeddings import EmbeddingGenerator
from .llm_client import MistralClient
from .prompts import format_rag_prompt, LEGAL_SYSTEM_PROMPT
from .config import DEFAULT_RETRIEVAL_K, PINECONE_API_KEY

# Import Pinecone - required for RAG chain
try:
    from .pinecone_vector_db import PineconeLegalVectorDB
    PINECONE_AVAILABLE = True
except ImportError:
    PINECONE_AVAILABLE = False
    PineconeLegalVectorDB = None

logger = logging.getLogger(__name__)

# Set up file logging
def _setup_rag_logging():
    """Ensure RAG chain logs are written to file"""
    try:
        from .logging_setup import setup_logging
        setup_logging("module_a.rag_chain")
    except Exception:
        pass  # Fallback to default logging if setup fails

_setup_rag_logging()


class LegalRAGChain:
    """
    Retrieval-Augmented Generation Chain for Legal Explanations
    Combines Vector DB retrieval with Mistral LLM generation
    
    NOTE: This RAG chain uses Pinecone only. ChromaDB integration has been removed.
    Make sure PINECONE_API_KEY is set before initializing.
    """
    
    def __init__(self):
        """Initialize the RAG chain components.

        Raises:
            ImportError: if the Pinecone client package is not installed.
            ValueError: if PINECONE_API_KEY is not configured.
            RuntimeError: if the Pinecone index cannot be reached/initialized.
        """
        logger.info("Initializing Legal RAG Chain...")
        
        # Check if Pinecone is available
        if not PINECONE_AVAILABLE:
            raise ImportError(
                "Pinecone client not installed. "
                "Install with: pip install pinecone-client[grpc]>=3.0.0"
            )
        
        # Check if API key is configured
        if not PINECONE_API_KEY:
            raise ValueError(
                "PINECONE_API_KEY must be set to use the RAG chain. "
                "Set it as an environment variable or in a .env file. "
                "Get your API key from: https://app.pinecone.io/"
            )
        
        # Initialize components
        self.embedder = EmbeddingGenerator()
        
        # Initialize Pinecone vector database
        logger.info("Initializing Pinecone vector database...")
        try:
            self.vector_db = PineconeLegalVectorDB()
            logger.info("✓ Using Pinecone cloud vector database")
        except Exception as e:
            logger.error("Failed to initialize Pinecone: %s", e)
            # Chain the original exception so the root cause (auth, network,
            # index config) is preserved in the traceback.
            raise RuntimeError(
                f"Pinecone initialization failed: {e}. "
                "Please check your API key and network connection. "
                "See module_a/PINECONE_SETUP.md for setup instructions."
            ) from e
        
        self.llm = MistralClient()
        
        logger.info("RAG Chain initialized successfully with Pinecone")
    
    def get_vector_db_info(self) -> Dict[str, Any]:
        """
        Get information about the Pinecone vector database
        
        Returns:
            Dictionary with database type, index name, vector count and
            class name of the backing vector-DB object.
        """
        return {
            "type": "Pinecone",
            "class_name": type(self.vector_db).__name__,
            "is_pinecone": True,
            # Fall back to "unknown" if the DB object lacks an index_name attr.
            "index_name": getattr(self.vector_db, "index_name", "unknown"),
            "vector_count": self.vector_db.get_count()
        }
    
    def run(
        self, 
        query: str, 
        k: int = DEFAULT_RETRIEVAL_K
    ) -> Dict[str, Any]:
        """
        Run the full RAG pipeline: retrieve, generate, cite.
        
        Args:
            query: User's question
            k: Number of chunks to retrieve
            
        Returns:
            Dictionary with 'query', 'explanation', and 'sources'
        """
        logger.info("Processing query: %s", query)
        
        # Step 1: Retrieve relevant chunks
        logger.info("Step 1: Retrieving relevant laws...")
        context_chunks = self._retrieve_chunks(query, k)
        logger.info("Retrieved %d relevant chunks", len(context_chunks))
        
        # Step 2: Generate explanation
        logger.info("Step 2: Generating explanation...")
        prompt = format_rag_prompt(query, context_chunks)
        try:
            explanation = self.llm.generate_response(
                prompt=prompt,
                system_prompt=LEGAL_SYSTEM_PROMPT
            )
        except Exception as e:
            # Best-effort: surface a user-facing apology instead of crashing.
            logger.error("Generation failed: %s", e)
            explanation = "I apologize, but I encountered an error while generating the explanation. Please try again later."
        
        # Step 3: Format output with improved source handling
        sources = self._build_sources(context_chunks)
        logger.info("Returning %d sources", len(sources))
        
        return {
            'query': query,
            'explanation': explanation,
            'sources': sources
        }
    
    def _retrieve_chunks(self, query: str, k: int) -> List[Dict[str, Any]]:
        """Embed *query* and return the top-*k* chunks as dicts with
        'text', 'metadata' and 'distance' keys."""
        query_embedding = self.embedder.generate_embedding(query)
        retrieval_results = self.vector_db.query_with_embedding(
            query_embedding.tolist(), 
            n_results=k
        )
        
        # Results come back in Chroma-style nested lists; flatten the first
        # (only) query's documents/metadatas/distances into one clean list.
        context_chunks: List[Dict[str, Any]] = []
        if retrieval_results['documents'][0]:
            for doc, metadata, distance in zip(
                retrieval_results['documents'][0],
                retrieval_results['metadatas'][0],
                retrieval_results['distances'][0]
            ):
                context_chunks.append({
                    'text': doc,
                    'metadata': metadata,
                    'distance': distance
                })
        return context_chunks
    
    def _build_sources(self, context_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Turn retrieved chunks into user-facing source citations.

        Prefers the 'article_section' metadata field; otherwise tries to
        extract an article number from the start of the chunk text, and
        finally falls back to a positional "Section N" label.
        """
        import re  # local import: top-level import block is outside this class
        # Compile once, outside the loop (the original re-imported and
        # re-scanned per chunk).
        article_pattern = re.compile(r'Article\s+(\d+[A-Za-z]?)')
        
        sources = []
        for i, chunk in enumerate(context_chunks):
            source_file = chunk['metadata'].get('source_file', 'Legal Document')
            article_section = chunk['metadata'].get('article_section')
            
            # If no specific section, try to extract from the text head
            if not article_section and 'Article' in chunk['text'][:200]:
                match = article_pattern.search(chunk['text'][:200])
                if match:
                    article_section = f"Article {match.group(1)}"
            
            sources.append({
                'file': source_file,
                'section': article_section or f"Section {i+1}",
                # Approximate similarity from distance; assumes distance in
                # [0, 1] (e.g. cosine) — TODO confirm against the index metric.
                'relevance_score': 1.0 - chunk['distance']
            })
        return sources