"""
RAG Chain Module
Orchestrates retrieval and generation for legal explanations
"""
import logging
import re
from typing import Dict, Any, List, Optional

from .embeddings import EmbeddingGenerator
from .llm_client import MistralClient
from .prompts import format_rag_prompt, LEGAL_SYSTEM_PROMPT
from .config import DEFAULT_RETRIEVAL_K, PINECONE_API_KEY
# Import Pinecone - required for RAG chain
try:
    from .pinecone_vector_db import PineconeLegalVectorDB
    PINECONE_AVAILABLE = True
except ImportError:
    PINECONE_AVAILABLE = False
    PineconeLegalVectorDB = None

logger = logging.getLogger(__name__)


# Set up file logging
def _setup_rag_logging():
    """Ensure RAG chain logs are written to file"""
    try:
        from .logging_setup import setup_logging
        setup_logging("module_a.rag_chain")
    except Exception:
        pass  # Fall back to default logging if setup fails


_setup_rag_logging()


class LegalRAGChain:
    """
    Retrieval-Augmented Generation Chain for Legal Explanations
    Combines Vector DB retrieval with Mistral LLM generation

    NOTE: This RAG chain uses Pinecone only. ChromaDB integration has been removed.
    Make sure PINECONE_API_KEY is set before initializing.
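
    Example usage (illustrative sketch; assumes PINECONE_API_KEY and the
    Mistral credentials are already configured, and the query text is made up):

        chain = LegalRAGChain()
        result = chain.run("What are my rights as a tenant?")
        print(result["explanation"])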
"""

    def __init__(self):
        """Initialize the RAG chain components"""
        logger.info("Initializing Legal RAG Chain...")

        # Check if Pinecone is available
        if not PINECONE_AVAILABLE:
            raise ImportError(
                "Pinecone client not installed. "
                "Install with: pip install pinecone-client[grpc]>=3.0.0"
            )

        # Check if API key is configured
        if not PINECONE_API_KEY:
            raise ValueError(
                "PINECONE_API_KEY must be set to use the RAG chain. "
                "Set it as an environment variable or in a .env file. "
                "Get your API key from: https://app.pinecone.io/"
            )

        # Initialize components
        self.embedder = EmbeddingGenerator()

        # Initialize Pinecone vector database
        logger.info("Initializing Pinecone vector database...")
        try:
            self.vector_db = PineconeLegalVectorDB()
            logger.info("✓ Using Pinecone cloud vector database")
        except Exception as e:
            logger.error(f"Failed to initialize Pinecone: {e}")
            raise RuntimeError(
                f"Pinecone initialization failed: {e}. "
                "Please check your API key and network connection. "
                "See module_a/PINECONE_SETUP.md for setup instructions."
            )

        self.llm = MistralClient()
        logger.info("RAG Chain initialized successfully with Pinecone")

    def get_vector_db_info(self) -> Dict[str, Any]:
        """
        Get information about the Pinecone vector database

        Returns:
            Dictionary with database type, name, and other info
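
        Example return value (shape taken from the code below; the index name
        and count are illustrative placeholders):
            {
                "type": "Pinecone",
                "class_name": "PineconeLegalVectorDB",
                "is_pinecone": True,
                "index_name": "legal-docs",
                "vector_count": 1234
            }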
"""
info = {
"type": "Pinecone",
"class_name": type(self.vector_db).__name__,
"is_pinecone": True,
"index_name": getattr(self.vector_db, "index_name", "unknown"),
"vector_count": self.vector_db.get_count()
}
return info

    def run(
        self,
        query: str,
        k: int = DEFAULT_RETRIEVAL_K
    ) -> Dict[str, Any]:
        """
        Run the full RAG pipeline

        Args:
            query: User's question
            k: Number of chunks to retrieve

        Returns:
            Dictionary with 'query', 'explanation', and 'sources'
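
            Example (illustrative values; keys match the dict built below):
                {
                    'query': 'What does the law say about data retention?',
                    'explanation': '...generated answer...',
                    'sources': [
                        {'file': 'Legal Document', 'section': 'Article 5',
                         'relevance_score': 0.83}
                    ]
                }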
"""
logger.info(f"Processing query: {query}")
# Step 1: Retrieve relevant chunks
logger.info("Step 1: Retrieving relevant laws...")
query_embedding = self.embedder.generate_embedding(query)
retrieval_results = self.vector_db.query_with_embedding(
query_embedding.tolist(),
n_results=k
)
# Process retrieval results into a clean list
context_chunks = []
if retrieval_results['documents'][0]:
for doc, metadata, distance in zip(
retrieval_results['documents'][0],
retrieval_results['metadatas'][0],
retrieval_results['distances'][0]
):
context_chunks.append({
'text': doc,
'metadata': metadata,
'distance': distance
})
logger.info(f"Retrieved {len(context_chunks)} relevant chunks")

        # Step 2: Generate explanation
        logger.info("Step 2: Generating explanation...")

        # Format prompt
        prompt = format_rag_prompt(query, context_chunks)

        # Call LLM
        try:
            explanation = self.llm.generate_response(
                prompt=prompt,
                system_prompt=LEGAL_SYSTEM_PROMPT
            )
        except Exception as e:
            logger.error(f"Generation failed: {e}")
            explanation = (
                "I apologize, but I encountered an error while generating "
                "the explanation. Please try again later."
            )

        # Step 3: Format output with improved source handling
        sources = []
        for i, chunk in enumerate(context_chunks):
            source_file = chunk['metadata'].get('source_file', 'Legal Document')
            article_section = chunk['metadata'].get('article_section')

            # If no specific section, try to extract the article number
            # from the beginning of the text
            if not article_section and 'Article' in chunk['text'][:200]:
                match = re.search(r'Article\s+(\d+[A-Za-z]?)', chunk['text'][:200])
                if match:
                    article_section = f"Article {match.group(1)}"

            # Create source entry
            source_entry = {
                'file': source_file,
                'section': article_section or f"Section {i+1}",
                'relevance_score': 1.0 - chunk['distance']  # Approx score
            }
            sources.append(source_entry)

        result = {
            'query': query,
            'explanation': explanation,
            'sources': sources
        }
        logger.info(f"Returning {len(sources)} sources")
        return result
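

# Illustrative manual smoke test (a sketch, not part of the module's original
# API): assumes PINECONE_API_KEY and the Mistral credentials used by
# MistralClient are configured, and that the Pinecone index already contains
# indexed legal documents. The query string is a made-up example.
if __name__ == "__main__":
    chain = LegalRAGChain()
    print(chain.get_vector_db_info())

    output = chain.run("What obligations do data controllers have?")
    print(output["explanation"])
    for src in output["sources"]:
        print(f"- {src['file']} ({src['section']}): relevance {src['relevance_score']:.2f}")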