# Legal-assistant / process_aware_rag.py
import os
from typing import Dict, List, Any

# Disable telemetry before importing chromadb
os.environ.setdefault("POSTHOG_DISABLED", "true")
os.environ.setdefault("CHROMA_TELEMETRY_DISABLED", "true")

import chromadb
import google.generativeai as genai
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from query_classifier import QueryClassifier
from graph_builder import LegalProcessGraph
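
# Runtime configuration read below (noted here for reference): GOOGLE_API_KEY
# is required for Gemini; CHROMA_DB_PATH and CACHE_ROOT optionally override
# the default storage locations.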


class ProcessAwareRAG:
    """Combines a query classifier, a legal-process graph, and a Chroma
    vector store to answer Indian-law questions via Gemini."""

    def __init__(self):
        # Initialize components
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        self.legal_graph.load_graph('legal_processes.pkl')

        # Initialize vector store
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        # Redirect model caches to a writable location
        default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(default_cache_root, exist_ok=True)
        for env_key in [
            'HF_HOME', 'TRANSFORMERS_CACHE',
            'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME'
        ]:
            os.environ.setdefault(env_key, os.path.join(default_cache_root, env_key.lower()))
            os.makedirs(os.environ[env_key], exist_ok=True)

        # Disable Chroma telemetry & init client
        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False)
        )

        # Ensure collection exists (get_or_create avoids the race-prone
        # get-then-create fallback)
        embedding_function = embedding_functions.DefaultEmbeddingFunction()
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_function
        )

        # Init LLM
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Pull the current step, next steps, and resources for the
        classified legal process from the process graph."""
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None
        }
        if classification['process'] == 'general':
            return graph_context

        process_name = classification['process_name']
        current_step_id = classification['step'] or self.legal_graph.find_process_start(process_name)

        if current_step_id and current_step_id in self.legal_graph.graph.nodes:
            graph_context['current_step'] = {
                'id': current_step_id,
                **self.legal_graph.graph.nodes[current_step_id]
            }
            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)

        return graph_context

    def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
        """Retrieve the top documents from Chroma, filtered by process
        name when the classifier identified one."""
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=3,
            # Only filter by process metadata when a process was identified
            where={"process": classification.get('process_name', '')} if classification.get('process_name') else None
        )
        vector_context = []
        if results['documents']:
            for i in range(len(results['documents'][0])):
                vector_context.append({
                    'content': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    # Guard with .get(): distances can be absent or None
                    'distance': results['distances'][0][i] if results.get('distances') else None
                })
        return vector_context

    def generate_response(self, user_query: str, history: List[Dict],
                          graph_context: Dict, vector_context: List[Dict],
                          classification: Dict) -> str:
        """
        Generate Indian Law Assistant responses:
        - Normal legal Q&A: short, clear answers
        - Complex / case-specific: suggest consulting a professional
        - If vector data is wrong: override with correct info
        """
        system_prompt = """
You are an **Indian Law Assistant**.

RULES:
- For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).
- Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.
- If the query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.
- If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.
- Be empathetic, clear, and concise.
- Use the same language as the user (Hindi/English mix if needed).
- Avoid long paragraphs. Use short, crisp answers with bullet points where possible.
"""

        # Format chat history (use last 5 messages for context)
        history_text = "\n".join([f"{m['role']}: {m['content']}" for m in history[-5:]]) if history else ""

        # Build context sections from graph and vector retrieval
        context_sections = []
        if graph_context['current_step']:
            context_sections.append(f"""
CURRENT STEP:
{graph_context['current_step']['title']} - {graph_context['current_step']['description']}
""")
        if graph_context['next_steps']:
            context_sections.append("NEXT STEPS:\n" + "\n".join([
                f"- {s['title']}: {s['description']}" for s in graph_context['next_steps']
            ]))
        if graph_context['resources']:
            context_sections.append("RESOURCES:\n" + "\n".join([
                f"- {r['title']} ({r['type']}): {r['properties'].get('url', r['properties'].get('phone', 'Contact available'))}"
                for r in graph_context['resources']
            ]))
        if vector_context:
            vector_text = "\n\n".join([doc['content'] for doc in vector_context])
            context_sections.append("REFERENCE CONTEXT:\n" + vector_text)

        # Join outside the f-string so no escape workarounds (chr(10)) are needed
        context_text = "\n".join(context_sections)
        full_prompt = f"""
{system_prompt}

CHAT HISTORY:
{history_text}

USER QUERY: "{user_query}"
CLASSIFICATION: Process: {classification.get('process_name', 'General')} | Intent: {classification.get('intent', 'information')}

CONTEXT:
{context_text}

Generate the best possible short, accurate, and user-friendly response.
"""
        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:
            return f"⚠️ System error. Please try again later or contact NALSA (nalsa-dla@nic.in). Error: {str(e)}"

    def process_query(self, user_query: str, history: List[Dict]) -> Dict[str, Any]:
        """Full pipeline: classify the query, gather graph and vector
        context, then generate the final response."""
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(user_query, history, graph_context, vector_context, classification)
        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
                'history_length': len(history)
            }
        }
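

# A minimal indexing sketch (an assumption, not part of the original pipeline):
# for the where-filter in retrieve_vector_context to match, documents must be
# stored with a "process" metadata key. The document text, id, and process
# value below are illustrative placeholders.
def seed_collection_example(rag: ProcessAwareRAG) -> None:
    rag.vector_collection.add(
        ids=["fir_doc_1"],  # hypothetical document id
        documents=["An FIR (First Information Report) is registered at the police station under Section 154 CrPC."],
        metadatas=[{"process": "fir_filing"}],  # hypothetical process name from the classifier
    )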


if __name__ == "__main__":
    rag_system = ProcessAwareRAG()

    test_query = "Under which IPC section is cheating punishable in India?"
    result = rag_system.process_query(test_query, history=[])

    print("=== QUERY ===")
    print(test_query)
    print("\n=== RESPONSE ===")
    print(result['response'])
    print("\n=== DEBUG INFO ===")
    print(result['debug_info'])
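
    # A minimal multi-turn sketch (hypothetical second turn): process_query
    # accepts prior messages as {'role': ..., 'content': ...} dicts, the same
    # shape generate_response formats into the prompt.
    follow_up_query = "What is the punishment under that section?"
    history = [
        {'role': 'user', 'content': test_query},
        {'role': 'assistant', 'content': result['response']},
    ]
    follow_up = rag_system.process_query(follow_up_query, history=history)
    print("\n=== FOLLOW-UP RESPONSE ===")
    print(follow_up['response'])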