import logging
import os
from typing import Any, Dict, List

# Disable telemetry before importing chromadb so the flags are already set
# when the library is imported.
os.environ.setdefault("POSTHOG_DISABLED", "true")
os.environ.setdefault("CHROMA_TELEMETRY_DISABLED", "true")

import chromadb  # noqa: E402
from chromadb.config import Settings  # noqa: E402
from chromadb.utils import embedding_functions  # noqa: E402
import google.generativeai as genai  # noqa: E402

from query_classifier import QueryClassifier  # noqa: E402
from graph_builder import LegalProcessGraph  # noqa: E402

logger = logging.getLogger(__name__)

# Instructions prepended to every LLM prompt. Hoisted to module level because
# it never changes between calls.
SYSTEM_PROMPT = """
You are an **Indian Law Assistant**.

RULES:
- For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).
- Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.
- If query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.
- If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.
- Be empathetic, clear, and concise.
- Use the same language as the user (Hindi/English mix if needed).
- Avoid long paragraphs. Use short, crisp answers with bullet points where possible.
"""


class ProcessAwareRAG:
    """Answer legal queries using a process graph, a vector store and an LLM.

    Pipeline (see ``process_query``):
      1. classify the query with ``QueryClassifier``;
      2. pull the current/next steps and resources from ``LegalProcessGraph``;
      3. pull similar reference documents from a Chroma collection;
      4. ask a Gemini model to answer using all of the above.
    """

    def __init__(self):
        """Build the classifier, process graph, vector store and LLM client.

        Environment variables read:
          - ``CHROMA_DB_PATH``: persistent Chroma directory
            (default ``/tmp/legal_vector_db``).
          - ``CACHE_ROOT``: root directory for model caches
            (default ``/data/cache``).
          - ``GOOGLE_API_KEY``: API key passed to ``genai.configure``.
        """
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        # NOTE(review): the .pkl extension suggests this unpickles the file;
        # only load graph files from a trusted source.
        self.legal_graph.load_graph('legal_processes.pkl')

        # Vector store location.
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        # Redirect model caches into a writable location. setdefault keeps any
        # value the deployment already configured.
        # NOTE(review): this runs after chromadb is imported; a library that
        # resolves its cache path at import time will not see these values.
        cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(cache_root, exist_ok=True)
        for env_key in (
            'HF_HOME',
            'TRANSFORMERS_CACHE',
            'SENTENCE_TRANSFORMERS_HOME',
            'XDG_CACHE_HOME',
        ):
            os.environ.setdefault(env_key, os.path.join(cache_root, env_key.lower()))
            os.makedirs(os.environ[env_key], exist_ok=True)

        # Persistent Chroma client with anonymized telemetry disabled.
        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False),
        )

        # get_or_create_collection replaces the previous get/except/create
        # pattern. That pattern also swallowed unrelated errors raised by
        # get_collection and then tried to create the collection anyway.
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_functions.DefaultEmbeddingFunction(),
        )

        # LLM client.
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Look up the current step, next steps and resources in the process graph.

        Args:
            classification: Output of ``QueryClassifier.classify_query``.
                Reads the keys ``process``, ``process_name`` and ``step``.

        Returns:
            A dict with the keys ``current_step`` (node attributes plus ``id``,
            or ``None``), ``next_steps`` (list), ``resources`` (list) and
            ``process_overview`` (always ``None`` here). The dict is left
            empty for the ``'general'`` process, or when no step can be
            resolved.
        """
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None,
        }
        if classification.get('process') == 'general':
            return graph_context

        process_name = classification.get('process_name')
        # Prefer the step the classifier detected. Otherwise fall back to the
        # start of the process, but only if a process name is known.
        current_step_id = classification.get('step')
        if not current_step_id and process_name:
            current_step_id = self.legal_graph.find_process_start(process_name)

        nodes = self.legal_graph.graph.nodes
        if current_step_id and current_step_id in nodes:
            graph_context['current_step'] = {'id': current_step_id, **nodes[current_step_id]}
            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)

        return graph_context

    def retrieve_vector_context(
        self, user_query: str, classification: Dict, n_results: int = 3
    ) -> List[Dict]:
        """Fetch the reference documents in the Chroma collection closest to the query.

        Args:
            user_query: Raw user text used as the similarity query.
            classification: Classifier output. If it has a truthy
                ``process_name``, results are filtered to documents whose
                ``process`` metadata equals that name.
            n_results: Maximum number of documents to return (default 3).

        Returns:
            A list of ``{'content', 'metadata', 'distance'}`` dicts in the
            order Chroma returned them. ``metadata`` and ``distance`` are
            ``None`` when Chroma did not return them.
        """
        process_name = classification.get('process_name')
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=n_results,
            where={"process": process_name} if process_name else None,
        )

        # Chroma returns one inner list per query text. We sent exactly one,
        # so take the first inner list. A field can be None or [] when it was
        # not included or nothing matched.
        documents = (results.get('documents') or [[]])[0]
        metadatas = (results.get('metadatas') or [[]])[0]
        distances = (results.get('distances') or [[]])[0]

        return [
            {
                'content': doc,
                'metadata': metadatas[i] if i < len(metadatas) else None,
                'distance': distances[i] if i < len(distances) else None,
            }
            for i, doc in enumerate(documents)
        ]

    def generate_response(self, user_query: str, history: List[Dict],
                          graph_context: Dict, vector_context: List[Dict],
                          classification: Dict) -> str:
        """Generate an Indian Law Assistant answer with the LLM.

        Behaviour requested from the model via ``SYSTEM_PROMPT``:
          - normal legal Q&A: short, clear answers;
          - complex / case-specific queries: advise consulting a professional;
          - wrong or irrelevant retrieved data: override with correct info.

        Args:
            user_query: The user's question.
            history: Prior chat messages as ``{'role', 'content'}`` dicts.
                Only the last 5 are included in the prompt.
            graph_context: Output of ``retrieve_graph_context``.
            vector_context: Output of ``retrieve_vector_context``.
            classification: Classifier output. ``process_name`` and
                ``intent`` are echoed into the prompt.

        Returns:
            The model's text. On any LLM failure, returns a user-facing error
            string instead, and the exception is logged.
        """
        # Only the last 5 messages, to keep the prompt small.
        history_text = (
            "\n".join(f"{m['role']}: {m['content']}" for m in history[-5:])
            if history else ""
        )

        context_sections = []
        current_step = graph_context.get('current_step')
        if current_step:
            context_sections.append(
                f"CURRENT STEP: {current_step.get('title', current_step['id'])} - "
                f"{current_step.get('description', '')}"
            )
        next_steps = graph_context.get('next_steps')
        if next_steps:
            context_sections.append("NEXT STEPS:\n" + "\n".join(
                f"- {s.get('title', '')}: {s.get('description', '')}" for s in next_steps
            ))
        resources = graph_context.get('resources')
        if resources:
            lines = []
            for r in resources:
                props = r.get('properties') or {}
                # Prefer a URL, then a phone number, then a generic placeholder.
                contact = props.get('url', props.get('phone', 'Contact available'))
                lines.append(f"- {r.get('title', '')} ({r.get('type', '')}): {contact}")
            context_sections.append("RESOURCES:\n" + "\n".join(lines))
        if vector_context:
            vector_text = "\n\n".join(doc['content'] for doc in vector_context)
            context_sections.append("REFERENCE CONTEXT:\n" + vector_text)

        context_text = "\n".join(context_sections)
        # `or` also covers a present-but-None value, so "None" is never shown.
        process_label = classification.get('process_name') or 'General'
        intent_label = classification.get('intent') or 'information'

        full_prompt = f"""
{SYSTEM_PROMPT}

CHAT HISTORY:
{history_text}

USER QUERY: "{user_query}"

CLASSIFICATION:
Process: {process_label} | Intent: {intent_label}

CONTEXT:
{context_text}

Generate the best possible short, accurate, and user-friendly response.
"""

        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:  # top-level boundary: degrade to a user message
            logger.exception("LLM generation failed")
            return (
                "⚠️ System error. Please try again later or contact NALSA "
                f"(nalsa-dla@nic.in). Error: {str(e)}"
            )

    def process_query(self, user_query: str, history: List[Dict]) -> Dict[str, Any]:
        """Run the full pipeline for one user query.

        Args:
            user_query: The user's question.
            history: Prior chat messages. ``None`` is treated as empty.

        Returns:
            A dict with ``response``, ``classification``, ``graph_context``,
            ``vector_context`` and ``debug_info``.
        """
        history = history or []
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(
            user_query, history, graph_context, vector_context, classification
        )

        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                # Counts next steps only (the current step is not included).
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
                'history_length': len(history),
            },
        }


if __name__ == "__main__":
    rag_system = ProcessAwareRAG()
    test_query = "Under which IPC section is cheating punishable in India?"
    result = rag_system.process_query(test_query, history=[])
    print("=== QUERY ===")
    print(test_query)
    print("\n=== RESPONSE ===")
    print(result['response'])
    print("\n=== DEBUG INFO ===")
    print(result['debug_info'])