|
|
""" |
|
|
RAG Chain Module |
|
|
Orchestrates retrieval and generation for legal explanations |
|
|
""" |
|
|
|
|
|
import logging
import re
from typing import Dict, Any, List, Optional

from .embeddings import EmbeddingGenerator
from .llm_client import MistralClient
from .prompts import format_rag_prompt, LEGAL_SYSTEM_PROMPT
from .config import DEFAULT_RETRIEVAL_K, PINECONE_API_KEY
|
|
|
|
|
|
|
|
try: |
|
|
from .pinecone_vector_db import PineconeLegalVectorDB |
|
|
PINECONE_AVAILABLE = True |
|
|
except ImportError: |
|
|
PINECONE_AVAILABLE = False |
|
|
PineconeLegalVectorDB = None |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def _setup_rag_logging(): |
|
|
"""Ensure RAG chain logs are written to file""" |
|
|
try: |
|
|
from .logging_setup import setup_logging |
|
|
setup_logging("module_a.rag_chain") |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
_setup_rag_logging() |
|
|
|
|
|
|
|
|
class LegalRAGChain:
    """
    Retrieval-Augmented Generation Chain for Legal Explanations

    Combines Vector DB retrieval with Mistral LLM generation.

    NOTE: This RAG chain uses Pinecone only. ChromaDB integration has been removed.
    Make sure PINECONE_API_KEY is set before initializing.
    """

    def __init__(self):
        """Initialize the RAG chain components.

        Raises:
            ImportError: the pinecone client package is not installed.
            ValueError: PINECONE_API_KEY is not configured.
            RuntimeError: the Pinecone vector database failed to initialize.
        """
        logger.info("Initializing Legal RAG Chain...")

        if not PINECONE_AVAILABLE:
            raise ImportError(
                "Pinecone client not installed. "
                "Install with: pip install pinecone-client[grpc]>=3.0.0"
            )

        if not PINECONE_API_KEY:
            raise ValueError(
                "PINECONE_API_KEY must be set to use the RAG chain. "
                "Set it as an environment variable or in a .env file. "
                "Get your API key from: https://app.pinecone.io/"
            )

        # Embedding model used to encode queries at retrieval time.
        self.embedder = EmbeddingGenerator()

        logger.info("Initializing Pinecone vector database...")
        try:
            self.vector_db = PineconeLegalVectorDB()
            logger.info("✓ Using Pinecone cloud vector database")
        except Exception as e:
            logger.error("Failed to initialize Pinecone: %s", e)
            # FIX: chain the original exception (`from e`) so the root cause
            # is preserved in tracebacks; the previous code dropped it.
            raise RuntimeError(
                f"Pinecone initialization failed: {e}. "
                "Please check your API key and network connection. "
                "See module_a/PINECONE_SETUP.md for setup instructions."
            ) from e

        self.llm = MistralClient()

        logger.info("RAG Chain initialized successfully with Pinecone")

    def get_vector_db_info(self) -> Dict[str, Any]:
        """
        Get information about the Pinecone vector database.

        Returns:
            Dictionary with database type, class name, index name, and
            current vector count.
        """
        return {
            "type": "Pinecone",
            "class_name": type(self.vector_db).__name__,
            "is_pinecone": True,
            # Fall back to "unknown" if the DB object has no index_name attr.
            "index_name": getattr(self.vector_db, "index_name", "unknown"),
            "vector_count": self.vector_db.get_count()
        }

    def run(
        self,
        query: str,
        k: int = DEFAULT_RETRIEVAL_K
    ) -> Dict[str, Any]:
        """
        Run the full RAG pipeline: retrieve relevant chunks, then generate
        an explanation grounded in them.

        Args:
            query: User's question
            k: Number of chunks to retrieve

        Returns:
            Dictionary with 'query', 'explanation', and 'sources'
        """
        logger.info("Processing query: %s", query)

        # Step 1: retrieve the most relevant law chunks by embedding similarity.
        logger.info("Step 1: Retrieving relevant laws...")
        query_embedding = self.embedder.generate_embedding(query)
        retrieval_results = self.vector_db.query_with_embedding(
            query_embedding.tolist(),
            n_results=k
        )

        context_chunks = self._collect_context_chunks(retrieval_results)
        logger.info("Retrieved %d relevant chunks", len(context_chunks))

        # Step 2: generate the explanation from the retrieved context.
        logger.info("Step 2: Generating explanation...")
        prompt = format_rag_prompt(query, context_chunks)

        try:
            explanation = self.llm.generate_response(
                prompt=prompt,
                system_prompt=LEGAL_SYSTEM_PROMPT
            )
        except Exception as e:
            # Degrade gracefully: surface an apology instead of propagating
            # LLM/API failures to the caller.
            logger.error("Generation failed: %s", e)
            explanation = "I apologize, but I encountered an error while generating the explanation. Please try again later."

        sources = self._build_sources(context_chunks)

        logger.info("Returning %d sources", len(sources))
        return {
            'query': query,
            'explanation': explanation,
            'sources': sources
        }

    @staticmethod
    def _collect_context_chunks(
        retrieval_results: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """Flatten a Chroma-style query result (parallel lists nested one
        level deep under 'documents'/'metadatas'/'distances') into a list of
        {'text', 'metadata', 'distance'} dicts."""
        context_chunks: List[Dict[str, Any]] = []
        if retrieval_results['documents'][0]:
            for doc, metadata, distance in zip(
                retrieval_results['documents'][0],
                retrieval_results['metadatas'][0],
                retrieval_results['distances'][0]
            ):
                context_chunks.append({
                    'text': doc,
                    'metadata': metadata,
                    'distance': distance
                })
        return context_chunks

    @staticmethod
    def _build_sources(
        context_chunks: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Build the user-facing source list for the retrieved chunks.

        When chunk metadata lacks an explicit 'article_section', fall back to
        scanning the start of the chunk text for an "Article N" heading.
        """
        # FIX: regex compiled once (and `re` imported at module level)
        # instead of `import re` + recompilation inside the loop.
        article_pattern = re.compile(r'Article\s+(\d+[A-Za-z]?)')

        sources: List[Dict[str, Any]] = []
        for i, chunk in enumerate(context_chunks):
            source_file = chunk['metadata'].get('source_file', 'Legal Document')
            article_section = chunk['metadata'].get('article_section')

            # Only the first 200 chars are scanned — article headings are
            # expected near the start of a chunk.
            if not article_section and 'Article' in chunk['text'][:200]:
                match = article_pattern.search(chunk['text'][:200])
                if match:
                    article_section = f"Article {match.group(1)}"

            sources.append({
                'file': source_file,
                'section': article_section or f"Section {i+1}",
                # NOTE(review): assumes `distance` is in [0, 1] (e.g. cosine
                # distance) so that 1 - distance is a sensible relevance —
                # confirm against the Pinecone index metric.
                'relevance_score': 1.0 - chunk['distance']
            })
        return sources
|
|
|