Spaces:
Running
Running
| import os | |
| # Disable telemetry before importing chromadb | |
| os.environ.setdefault("POSTHOG_DISABLED", "true") | |
| os.environ.setdefault("CHROMA_TELEMETRY_DISABLED", "true") | |
| import chromadb | |
| from chromadb.config import Settings | |
| from chromadb.utils import embedding_functions | |
| import google.generativeai as genai | |
| from query_classifier import QueryClassifier | |
| from graph_builder import LegalProcessGraph | |
| from typing import Dict, List, Any | |
class ProcessAwareRAG:
    """Answer Indian-law questions by combining a process graph, a vector store and an LLM.

    ``process_query`` is the entry point. It classifies the query, pulls the
    current step, next steps and resources from ``LegalProcessGraph``, pulls
    reference passages from the Chroma ``legal_context`` collection, and then
    prompts the Gemini model with all of it.
    """

    # Instructions prepended to every LLM prompt.
    _SYSTEM_PROMPT = (
        "You are an **Indian Law Assistant**.\n"
        "RULES:\n"
        "- For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).\n"
        "- Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.\n"
        "- If query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.\n"
        "- If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.\n"
        "- Be empathetic, clear, and concise.\n"
        "- Use the same language as the user (Hindi/English mix if needed).\n"
        "- Avoid long paragraphs. Use short, crisp answers with bullet points where possible.\n"
    )
    # How many of the most recent chat messages are included in the prompt.
    _HISTORY_WINDOW = 5
    # How many passages are requested from the vector store per query.
    _VECTOR_RESULTS = 3

    def __init__(self):
        """Load the process graph, open the Chroma collection and configure Gemini.

        Environment variables read:
            CHROMA_DB_PATH: Chroma persistence directory (default ``/tmp/legal_vector_db``).
            CACHE_ROOT: root directory for model caches (default ``/data/cache``).
            GOOGLE_API_KEY: API key passed to ``genai.configure``.
        """
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        # NOTE(review): the .pkl extension suggests this file is unpickled;
        # only ever point it at a trusted file. Confirm in graph_builder.
        self.legal_graph.load_graph('legal_processes.pkl')

        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        self._configure_model_caches()

        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False),
        )
        # get_or_create replaces the old "get, and on *any* exception create"
        # fallback. That fallback also swallowed unrelated store errors and
        # then surfaced a misleading failure from create_collection.
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_functions.DefaultEmbeddingFunction(),
        )

        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')

    @staticmethod
    def _configure_model_caches() -> None:
        """Default the HF, sentence-transformers and XDG cache dirs to folders under CACHE_ROOT.

        Values that are already set are kept (``setdefault``). Every resulting
        directory is created. Presumably this is needed because the default
        home directory is not writable on the deployment host (TODO confirm).
        Libraries that read these variables at import time will not see values
        set here.
        """
        cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(cache_root, exist_ok=True)
        for env_key in ('HF_HOME', 'TRANSFORMERS_CACHE',
                        'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME'):
            os.environ.setdefault(env_key, os.path.join(cache_root, env_key.lower()))
            os.makedirs(os.environ[env_key], exist_ok=True)

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Return the current step, next steps and resources for the classified process.

        Args:
            classification: result of ``QueryClassifier.classify_query``. The
                keys ``process``, ``process_name`` and ``step`` are read, and
                all of them are optional.

        Returns:
            A dict with these keys:

            - ``current_step``: the node's attributes plus ``id``, or None.
            - ``next_steps`` and ``resources``: lists returned by the graph helpers.
            - ``process_overview``: always None here.
        """
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None,
        }
        # A missing 'process' is treated like 'general': there is nothing to look up.
        if classification.get('process', 'general') == 'general':
            return graph_context

        process_name = classification.get('process_name')
        current_step_id = classification.get('step')
        if not current_step_id and process_name:
            # No explicit step was classified, so start at the process entry point.
            current_step_id = self.legal_graph.find_process_start(process_name)

        nodes = self.legal_graph.graph.nodes
        if not current_step_id or current_step_id not in nodes:
            return graph_context

        graph_context['current_step'] = {'id': current_step_id, **nodes[current_step_id]}
        graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
        graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
        return graph_context

    def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
        """Query the vector store for passages relevant to ``user_query``.

        When ``classification`` carries a truthy ``process_name``, results are
        filtered to documents whose ``process`` metadata matches it.

        Returns:
            A list of ``{'content', 'metadata', 'distance'}`` dicts, one per hit.
            ``metadata`` and ``distance`` are None when the store did not return them.
        """
        process_name = classification.get('process_name')
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=self._VECTOR_RESULTS,
            where={"process": process_name} if process_name else None,
        )
        # Chroma returns one inner list per query text; only one text was sent.
        # Fields that were not requested come back as None rather than being
        # absent, so each field is guarded individually.
        documents = (results.get('documents') or [[]])[0] or []
        metadatas = (results.get('metadatas') or [[]])[0] or []
        distances = (results.get('distances') or [[]])[0] or []
        return [
            {
                'content': doc,
                'metadata': metadatas[i] if i < len(metadatas) else None,
                'distance': distances[i] if i < len(distances) else None,
            }
            for i, doc in enumerate(documents)
        ]

    @staticmethod
    def _build_context_sections(graph_context: Dict, vector_context: List[Dict]) -> List[str]:
        """Render the graph and vector context as labelled text sections for the prompt."""
        sections = []
        step = graph_context.get('current_step')
        if step:
            sections.append(
                "CURRENT STEP:\n"
                f"{step.get('title', step.get('id', ''))} - {step.get('description', '')}"
            )
        next_steps = graph_context.get('next_steps') or []
        if next_steps:
            sections.append("NEXT STEPS:\n" + "\n".join(
                f"- {s.get('title', '')}: {s.get('description', '')}" for s in next_steps
            ))
        resources = graph_context.get('resources') or []
        if resources:
            lines = []
            for res in resources:
                props = res.get('properties') or {}
                contact = props.get('url', props.get('phone', 'Contact available'))
                lines.append(f"- {res.get('title', '')} ({res.get('type', '')}): {contact}")
            sections.append("RESOURCES:\n" + "\n".join(lines))
        if vector_context:
            sections.append(
                "REFERENCE CONTEXT:\n" + "\n\n".join(doc['content'] for doc in vector_context)
            )
        return sections

    def generate_response(self, user_query: str, history: List[Dict],
                          graph_context: Dict, vector_context: List[Dict],
                          classification: Dict) -> str:
        """Generate an Indian Law Assistant answer with the LLM.

        The system prompt asks for the following:

        - Short, clear answers for normal legal Q&A.
        - A suggestion to consult a professional for complex or case-specific queries.
        - Overriding retrieved data that looks wrong with correct general knowledge.

        Args:
            user_query: the user's question.
            history: prior messages as ``{'role', 'content'}`` dicts. Only the
                last ``_HISTORY_WINDOW`` messages are used.
            graph_context: output of ``retrieve_graph_context``.
            vector_context: output of ``retrieve_vector_context``.
            classification: output of ``QueryClassifier.classify_query``.

        Returns:
            The model's text. If the LLM call fails, a user-facing error message
            is returned instead.
        """
        recent = (history or [])[-self._HISTORY_WINDOW:]
        history_text = "\n".join(f"{m['role']}: {m['content']}" for m in recent)

        context_sections = self._build_context_sections(graph_context, vector_context)
        process_label = classification.get('process_name') or 'General'
        intent = classification.get('intent') or 'information'

        full_prompt = "\n".join([
            self._SYSTEM_PROMPT,
            "CHAT HISTORY:",
            history_text,
            f'USER QUERY: "{user_query}"',
            f"CLASSIFICATION: Process: {process_label} | Intent: {intent}",
            "CONTEXT:",
            "\n".join(context_sections),
            "Generate the best possible short, accurate, and user-friendly response.",
        ])
        try:
            return self.llm.generate_content(full_prompt).text
        except Exception as e:
            # UI boundary: any LLM/API failure (including reading .text) becomes
            # a friendly message instead of crashing the chat handler.
            return f"⚠️ System error. Please try again later or contact NALSA (nalsa-dla@nic.in). Error: {str(e)}"

    def process_query(self, user_query: str, history: List[Dict]) -> Dict[str, Any]:
        """Run the full pipeline (classify, retrieve graph and vector context, generate).

        Args:
            user_query: the user's question.
            history: prior chat messages. None is treated as empty.

        Returns:
            A dict containing the response, the intermediate contexts, the
            classification and some debug counters.
        """
        history = history or []
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(
            user_query, history, graph_context, vector_context, classification
        )
        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                # Counts next steps only (not the current step or resources).
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
                'history_length': len(history),
            },
        }
if __name__ == "__main__":
    # Smoke test: push one sample question through the whole pipeline and
    # print the query, the model's answer and the debug counters.
    assistant = ProcessAwareRAG()
    sample_question = "Under which IPC section is cheating punishable in India?"
    outcome = assistant.process_query(sample_question, history=[])
    for banner, payload in (
        ("=== QUERY ===", sample_question),
        ("\n=== RESPONSE ===", outcome['response']),
        ("\n=== DEBUG INFO ===", outcome['debug_info']),
    ):
        print(banner)
        print(payload)