# Legal-assistant / process_aware_rag.py
# Last change by Rivalcoder: "Update process_aware_rag.py" (commit 2497a16, verified)
# File size: 10.6 kB
# process_aware_rag.py
import os
# Disable telemetry before importing chromadb
os.environ.setdefault("POSTHOG_DISABLED", "true")
os.environ.setdefault("CHROMA_TELEMETRY_DISABLED", "true")
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import google.generativeai as genai
from query_classifier import QueryClassifier
from graph_builder import LegalProcessGraph
from typing import Dict, List, Any
class ProcessAwareRAG:
    """Process-aware retrieval-augmented generation for Indian legal processes.

    Each query is classified (``QueryClassifier``), enriched with the matching
    step, next steps and resources from a ``LegalProcessGraph`` and with
    similar passages from the Chroma collection ``legal_context``, and is then
    answered by a Gemini model (see ``process_query``).
    """

    def __init__(self):
        # Query classifier and the pre-built legal process graph.
        # NOTE(review): 'legal_processes.pkl' is resolved against the current
        # working directory, not this module's directory.
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        self.legal_graph.load_graph('legal_processes.pkl')

        # Vector store location (a writable path by default).
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        self._prepare_cache_dirs()

        # Disable Chroma anonymized telemetry in the client settings as well.
        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False),
        )
        # An explicit embedding function lets ``query_texts`` be embedded at
        # query time. get_or_create avoids a blanket ``except Exception`` that
        # would turn any lookup error into a (confusing) create attempt.
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_functions.DefaultEmbeddingFunction(),
        )

        # LLM client; the API key is read from the GOOGLE_API_KEY env var.
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')

    @staticmethod
    def _prepare_cache_dirs() -> None:
        """Point model caches at writable directories and create them.

        The cache root comes from ``CACHE_ROOT`` (default ``/data/cache``).
        Variables that are already set are left untouched (``setdefault``).

        NOTE(review): libraries that read these variables when first imported
        (possibly via this module's top-level imports) will not see values
        set here — confirm the redirect actually takes effect.
        """
        cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(cache_root, exist_ok=True)
        # Creating ~/.cache/chroma also creates ~/.cache.
        os.makedirs(os.path.join(os.environ['HOME'], '.cache', 'chroma'), exist_ok=True)

        cache_dirs = {
            'HF_HOME': os.path.join(cache_root, 'hf'),
            'TRANSFORMERS_CACHE': os.path.join(cache_root, 'transformers'),
            'SENTENCE_TRANSFORMERS_HOME': os.path.join(cache_root, 'sentence-transformers'),
            'XDG_CACHE_HOME': cache_root,
        }
        for env_key, default_dir in cache_dirs.items():
            os.environ.setdefault(env_key, default_dir)
            os.makedirs(os.environ[env_key], exist_ok=True)

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Retrieve the current step, its next steps and resources from the graph.

        Args:
            classification: classifier output. This method reads ``'process'``,
                ``'step'`` and (when no step is given) ``'process_name'``.

        Returns:
            Dict with ``current_step`` (node attributes plus ``'id'``, or None),
            ``next_steps`` (list), ``resources`` (list) and ``process_overview``
            (always None here). For ``process == 'general'``, or a step id that
            is not a node of the graph, the lists stay empty.
        """
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None,
        }
        if classification['process'] == 'general':
            return graph_context

        # Prefer the classified step; otherwise start at the process entry point.
        current_step_id = classification.get('step')
        if not current_step_id:
            current_step_id = self.legal_graph.find_process_start(classification['process_name'])

        graph = self.legal_graph.graph
        # Only ask for neighbours of nodes that exist in the graph.
        if not current_step_id or current_step_id not in graph.nodes:
            return graph_context

        graph_context['current_step'] = {'id': current_step_id, **graph.nodes[current_step_id]}
        graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
        graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
        return graph_context

    def retrieve_vector_context(self, user_query: str, classification: Dict, n_results: int = 3) -> List[Dict]:
        """Retrieve passages similar to ``user_query`` from the vector store.

        Args:
            user_query: text embedded and matched against the collection.
            classification: if it has a truthy ``'process_name'``, results are
                restricted to documents whose ``process`` metadata equals it.
            n_results: maximum number of passages (clamped to collection size).

        Returns:
            List of ``{'content', 'metadata', 'distance'}`` dicts; ``metadata``
            / ``distance`` are None when Chroma did not return them. Empty when
            the collection is empty.
        """
        available = self.vector_collection.count()
        # A freshly created collection is empty; skip the query entirely.
        if available == 0 or n_results <= 0:
            return []

        process_name = classification.get('process_name')
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=min(n_results, available),
            where={"process": process_name} if process_name else None,
        )

        def first_row(key):
            # Chroma returns one row per query text; a key may be absent or None.
            rows = results.get(key)
            return rows[0] if rows else None

        documents = first_row('documents') or []
        metadatas = first_row('metadatas')
        distances = first_row('distances')
        return [
            {
                'content': document,
                'metadata': metadatas[i] if metadatas else None,
                'distance': distances[i] if distances else None,
            }
            for i, document in enumerate(documents)
        ]

    @staticmethod
    def _build_context_sections(graph_context: Dict, vector_context: List[Dict]) -> List[str]:
        """Render graph and vector context as prompt sections, skipping empty parts."""
        sections = []

        step = graph_context.get('current_step')
        if step:
            sections.append(f"""
CURRENT PROCESS STEP:
Title: {step.get('title', '')}
Description: {step.get('description', '')}
Properties: {step.get('properties', {})}
""")

        next_steps = graph_context.get('next_steps')
        if next_steps:
            next_steps_text = "\n".join(
                f"- {s.get('title', '')}: {s.get('description', '')}" for s in next_steps
            )
            sections.append(f"""
NEXT STEPS:
{next_steps_text}
""")

        resources = graph_context.get('resources')
        if resources:
            resource_lines = []
            for res in resources:
                props = res.get('properties') or {}
                # Prefer a URL, then a phone number, then a generic placeholder.
                contact = props.get('url', props.get('phone', 'Contact available'))
                resource_lines.append(f"- {res.get('title', '')} ({res.get('type', '')}): {contact}")
            resources_text = "\n".join(resource_lines)
            sections.append(f"""
RELEVANT RESOURCES:
{resources_text}
""")

        # Skip documents stored without text so join() cannot fail on None.
        passages = [doc['content'] for doc in vector_context if doc.get('content')]
        if passages:
            vector_text = "\n\n".join(passages)
            sections.append(f"""
ADDITIONAL CONTEXT:
{vector_text}
""")
        return sections

    def generate_response(self, user_query: str, graph_context: Dict, vector_context: List[Dict], classification: Dict) -> str:
        """Build the LLM prompt from the retrieved context and return the reply.

        Args:
            user_query: the user's question, quoted verbatim into the prompt.
            graph_context: output of ``retrieve_graph_context``.
            vector_context: output of ``retrieve_vector_context``.
            classification: classifier output; ``process_name`` and ``intent``
                are echoed into the prompt.

        Returns:
            The model's text, or a fallback apology that includes the error
            text if generation fails for any reason.
        """
        # NOTE(review): the guidelines below ask for 1-3 line answers while the
        # closing instructions ask for sections and bullet points, so the model
        # receives conflicting length/format instructions.
        system_prompt = """
You are a helpful, empathetic, and precise legal guide for Indian legal processes.
IMPORTANT GUIDELINES:
- You must NEVER give legal advice, only guide users through official processes
- Always stress that users should consult a qualified lawyer for their specific case
- Provide specific, actionable steps when possible
- Include official links, phone numbers, and government portals when available
- Be empathetic and understanding of the user's situation
- Use clear, simple language avoiding legal jargon
- Give In The User Language What User Is Communicating or Giving Questions SO Answer In That Language Accordingly
- And Must be Short and Clear Not To give Like long Long Para answers Answer must be Short and Clear
- Dont Give Large Answer Give Answer In One or two three Lines answers Give Accordingly
"""
        context_text = "\n".join(self._build_context_sections(graph_context, vector_context))
        full_prompt = f"""
{system_prompt}
USER QUERY: "{user_query}"
CLASSIFICATION: Process: {classification.get('process_name', 'General')}, Intent: {classification.get('intent', 'information')}
{context_text}
Please provide a helpful, structured response that:
1. Acknowledges the user's situation empathetically
2. Provides specific next steps if this is a process guidance request
3. Includes relevant official links and contact information
4. Reminds the user to consult a lawyer for specific legal advice
5. Uses bullet points, bold formatting, and clear structure
Format your response with clear sections and actionable information.
"""
        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:
            # Best-effort fallback: the caller always receives a string.
            # NOTE(review): this surfaces internal error text to end users.
            return f"I apologize, but I'm having trouble generating a response right now. Please try again or contact NALSA directly at nalsa-dla@nic.in for legal aid queries. Error: {str(e)}"

    def process_query(self, user_query: str) -> Dict[str, Any]:
        """Run the full pipeline: classify, retrieve graph and vector context, generate.

        Returns:
            Dict with the ``response`` text, the intermediate ``classification``,
            ``graph_context`` and ``vector_context``, and a ``debug_info`` summary.
        """
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(user_query, graph_context, vector_context, classification)
        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                # Counts only the next steps, not the current step or resources.
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
            },
        }
# Smoke-test the full pipeline with one sample query when run as a script.
if __name__ == "__main__":
    rag = ProcessAwareRAG()
    sample_query = "I need free legal help but I'm not sure if I qualify. My monthly income is around 45000 rupees."
    outcome = rag.process_query(sample_query)

    print("=== QUERY ===")
    print(sample_query)
    print("\n=== RESPONSE ===")
    print(outcome['response'])

    debug = outcome['debug_info']
    print("\n=== DEBUG INFO ===")
    print(f"Process: {debug['process_identified']}")
    print(f"Graph nodes: {debug['graph_nodes_found']}")
    print(f"Vector docs: {debug['vector_docs_found']}")