# Source: Legal-assistant / process_aware_rag.py (Hugging Face Space by Rivalcoder, commit a2b02a5)
# process_aware_rag.py
import chromadb
import google.generativeai as genai
from query_classifier import QueryClassifier
from graph_builder import LegalProcessGraph
import os
from typing import Dict, List, Any
class ProcessAwareRAG:
    """Process-aware RAG pipeline for Indian legal-process guidance.

    Combines three components:
      1. A query classifier that maps a user question to a legal process,
         step, and intent.
      2. A legal-process knowledge graph (current step, next steps, resources).
      3. A Chroma vector store of supporting documents.
    The merged context is sent to Gemini to generate a structured answer.
    """

    def __init__(self):
        # Classifier and process knowledge graph (loaded from a pickle on disk).
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        self.legal_graph.load_graph('legal_processes.pkl')

        # Vector store — use a writable path by default (/tmp works on
        # read-only container filesystems such as HF Spaces).
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)
        client = chromadb.PersistentClient(path=chroma_path)
        try:
            self.vector_collection = client.get_collection("legal_context")
        except Exception:
            # Collection does not exist yet on a fresh database — create it.
            self.vector_collection = client.create_collection("legal_context")

        # LLM client. NOTE(review): if GOOGLE_API_KEY is unset this defers the
        # failure to the first generate_content call — confirm that is intended.
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-1.5-flash')

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Retrieve relevant context from the knowledge graph.

        Args:
            classification: classifier output; relevant keys are 'process',
                'process_name' and 'step' (all optional — missing keys are
                treated as absent rather than raising KeyError).

        Returns:
            Dict with 'current_step' (node payload or None), 'next_steps'
            (list), 'resources' (list) and 'process_overview' (always None
            here; reserved for callers).
        """
        graph_context: Dict[str, Any] = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None,
        }
        # General (non-process) queries get no graph context.
        if classification.get('process') == 'general':
            return graph_context

        process_name = classification.get('process_name')
        # Use the classified step if present, otherwise start at the
        # beginning of the identified process.
        current_step_id = classification.get('step') or \
            self.legal_graph.find_process_start(process_name)

        if current_step_id:
            if current_step_id in self.legal_graph.graph.nodes:
                graph_context['current_step'] = {
                    'id': current_step_id,
                    **self.legal_graph.graph.nodes[current_step_id],
                }
            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
        return graph_context

    def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
        """Retrieve up to three semantically similar documents from the vector store.

        Filters by process name when the classifier identified one. Chroma may
        return 'metadatas'/'distances' keys whose value is None, so each field
        is guarded before indexing (the original code crashed in that case).
        """
        process_name = classification.get('process_name')
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=3,
            where={"process": process_name} if process_name else None,
        )

        vector_context: List[Dict] = []
        documents = results.get('documents') or []
        if documents:
            # Parallel lists; guard each against None and short lengths.
            metadatas = (results.get('metadatas') or [[]])[0] or []
            distances = (results.get('distances') or [[]])[0] or []
            for i, content in enumerate(documents[0]):
                vector_context.append({
                    'content': content,
                    'metadata': metadatas[i] if i < len(metadatas) else None,
                    'distance': distances[i] if i < len(distances) else None,
                })
        return vector_context

    def generate_response(self, user_query: str, graph_context: Dict,
                          vector_context: List[Dict], classification: Dict) -> str:
        """Generate the final answer with Gemini from all gathered context.

        Builds a structured prompt (guidelines + graph step/next-steps/resources
        + vector documents) and returns the model's text. On any LLM failure a
        user-facing fallback message containing the error is returned instead
        of raising.
        """
        system_prompt = """
You are a helpful, empathetic, and precise legal guide for Indian legal processes.
IMPORTANT GUIDELINES:
- You must NEVER give legal advice, only guide users through official processes
- Always stress that users should consult a qualified lawyer for their specific case
- Provide specific, actionable steps when possible
- Include official links, phone numbers, and government portals when available
- Be empathetic and understanding of the user's situation
- Use clear, simple language avoiding legal jargon
"""
        context_sections = []

        # Graph context: current step.
        current_step = graph_context.get('current_step')
        if current_step:
            context_sections.append(f"""
CURRENT PROCESS STEP:
Title: {current_step.get('title', '')}
Description: {current_step.get('description', '')}
Properties: {current_step.get('properties', {})}
""")

        # Graph context: next steps.
        if graph_context.get('next_steps'):
            next_steps_text = "\n".join([
                f"- {step['title']}: {step['description']}"
                for step in graph_context['next_steps']
            ])
            context_sections.append(f"""
NEXT STEPS:
{next_steps_text}
""")

        # Graph context: resources (prefer URL, then phone, then a placeholder).
        if graph_context.get('resources'):
            resources_text = "\n".join([
                f"- {res['title']} ({res['type']}): {res['properties'].get('url', res['properties'].get('phone', 'Contact available'))}"
                for res in graph_context['resources']
            ])
            context_sections.append(f"""
RELEVANT RESOURCES:
{resources_text}
""")

        # Vector-store context.
        if vector_context:
            vector_text = "\n\n".join([doc['content'] for doc in vector_context])
            context_sections.append(f"""
ADDITIONAL CONTEXT:
{vector_text}
""")

        # chr(10) is '\n' — backslashes are not allowed inside f-string
        # expressions before Python 3.12.
        full_prompt = f"""
{system_prompt}
USER QUERY: "{user_query}"
CLASSIFICATION: Process: {classification.get('process_name', 'General')}, Intent: {classification.get('intent', 'information')}
{chr(10).join(context_sections)}
Please provide a helpful, structured response that:
1. Acknowledges the user's situation empathetically
2. Provides specific next steps if this is a process guidance request
3. Includes relevant official links and contact information
4. Reminds the user to consult a lawyer for specific legal advice
5. Uses bullet points, bold formatting, and clear structure
Format your response with clear sections and actionable information.
"""
        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:
            # Graceful degradation: never propagate LLM errors to the caller.
            return f"I apologize, but I'm having trouble generating a response right now. Please try again or contact NALSA directly at nalsa-dla@nic.in for legal aid queries. Error: {str(e)}"

    def process_query(self, user_query: str) -> Dict[str, Any]:
        """Run the full pipeline: classify, retrieve (graph + vector), generate.

        Returns the response text plus every intermediate artifact and a small
        debug summary, so callers/UIs can inspect what the pipeline found.
        """
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(user_query, graph_context, vector_context, classification)
        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
            },
        }
# Quick end-to-end smoke test of the full pipeline.
if __name__ == "__main__":
    pipeline = ProcessAwareRAG()
    sample_question = (
        "I need free legal help but I'm not sure if I qualify. "
        "My monthly income is around 45000 rupees."
    )
    outcome = pipeline.process_query(sample_question)

    print("=== QUERY ===")
    print(sample_question)
    print("\n=== RESPONSE ===")
    print(outcome['response'])
    print("\n=== DEBUG INFO ===")
    debug = outcome['debug_info']
    print(f"Process: {debug['process_identified']}")
    print(f"Graph nodes: {debug['graph_nodes_found']}")
    print(f"Vector docs: {debug['vector_docs_found']}")