# NOTE(review): removed non-code residue ("Spaces: / Running / Running")
# left over from a web copy-paste; it is not part of the program.
# process_aware_rag.py
import os
from typing import Any, Dict, List

import chromadb
import google.generativeai as genai
from chromadb.config import Settings
from chromadb.utils import embedding_functions

from graph_builder import LegalProcessGraph
from query_classifier import QueryClassifier
class ProcessAwareRAG:
    """Process-aware RAG pipeline for guiding users through Indian legal processes.

    Combines three signals to answer a query:
      1. a query classifier (process / step / intent),
      2. a legal-process knowledge graph (current step, next steps, resources),
      3. a Chroma vector store of supporting documents,
    then asks a Gemini model to compose the final, user-facing answer.
    """

    def __init__(self):
        """Wire up classifier, process graph, vector store, and LLM.

        Environment variables read:
            CHROMA_DB_PATH: persistent Chroma directory (default /tmp/legal_vector_db).
            CACHE_ROOT:     root for model/download caches (default /data/cache).
            GOOGLE_API_KEY: Gemini API key (validated lazily on first LLM call).
        """
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        self.legal_graph.load_graph('legal_processes.pkl')

        # Vector store lives in a writable path by default (containers often
        # mount the app directory read-only).
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        # Redirect model/download caches to writable directories before any
        # embedding model loads; setdefault keeps caller-provided values.
        default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(default_cache_root, exist_ok=True)
        os.makedirs(os.path.join(os.environ['HOME'], '.cache'), exist_ok=True)
        os.makedirs(os.path.join(os.environ['HOME'], '.cache', 'chroma'), exist_ok=True)
        cache_dirs = {
            'HF_HOME': os.path.join(default_cache_root, 'hf'),
            'TRANSFORMERS_CACHE': os.path.join(default_cache_root, 'transformers'),
            'SENTENCE_TRANSFORMERS_HOME': os.path.join(default_cache_root, 'sentence-transformers'),
            'XDG_CACHE_HOME': default_cache_root,
        }
        for env_key, default_dir in cache_dirs.items():
            os.environ.setdefault(env_key, default_dir)
            os.makedirs(os.environ[env_key], exist_ok=True)

        # Disable Chroma anonymized telemetry and initialize the client.
        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False)
        )
        # Explicit embedding function so queries can always compute embeddings.
        # get_or_create_collection replaces the former try/except-Exception
        # around get_collection, which could mask real errors (e.g. a corrupt
        # DB would silently fall through to create_collection and fail oddly).
        embedding_function = embedding_functions.DefaultEmbeddingFunction()
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_function
        )

        # Initialize LLM; a missing GOOGLE_API_KEY surfaces on the first call.
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-1.5-flash')

    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Retrieve relevant context from the knowledge graph.

        Args:
            classification: classifier output; expects 'process', and for
                non-general queries 'process_name' and optionally 'step'.

        Returns:
            Dict with 'current_step', 'next_steps', 'resources', and
            'process_overview'. NOTE(review): 'process_overview' is never
            populated anywhere in this class — confirm whether callers need
            it or the key can be dropped.
        """
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None
        }

        # General queries have no associated process graph.
        if classification['process'] == 'general':
            return graph_context

        process_name = classification['process_name']

        # Anchor on the classified step, else on the process entry point.
        if classification['step']:
            current_step_id = classification['step']
        else:
            current_step_id = self.legal_graph.find_process_start(process_name)

        if current_step_id:
            # Current step info (only if the node actually exists).
            if current_step_id in self.legal_graph.graph.nodes:
                graph_context['current_step'] = {
                    'id': current_step_id,
                    **self.legal_graph.graph.nodes[current_step_id]
                }
            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)

        return graph_context

    def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
        """Retrieve up to 3 relevant documents from the vector store.

        Filters by process name when the classifier identified one.

        Returns:
            List of {'content', 'metadata', 'distance'} dicts; 'metadata'
            and 'distance' are None when Chroma did not include them.
        """
        process_name = classification.get('process_name')
        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=3,
            where={"process": process_name} if process_name else None
        )

        vector_context = []
        documents = results.get('documents')
        if documents and documents[0]:
            metadatas = results.get('metadatas')
            # Chroma may return 'distances'/'metadatas' keys with a None
            # value; guard the value, not just key presence (the former
            # `'distances' in results` check crashed subscripting None).
            distances = results.get('distances')
            for i, content in enumerate(documents[0]):
                vector_context.append({
                    'content': content,
                    'metadata': metadatas[0][i] if metadatas else None,
                    'distance': distances[0][i] if distances else None
                })
        return vector_context

    def generate_response(self, user_query: str, graph_context: Dict, vector_context: List[Dict], classification: Dict) -> str:
        """Generate the final answer with the LLM.

        Builds a structured prompt from the system guidelines, graph context,
        and vector context. Never raises: on any LLM failure it returns a
        user-facing fallback message containing the error text.
        """
        system_prompt = """
You are a helpful, empathetic, and precise legal guide for Indian legal processes.
IMPORTANT GUIDELINES:
- You must NEVER give legal advice, only guide users through official processes
- Always stress that users should consult a qualified lawyer for their specific case
- Provide specific, actionable steps when possible
- Include official links, phone numbers, and government portals when available
- Be empathetic and understanding of the user's situation
- Use clear, simple language avoiding legal jargon
"""

        context_sections = []

        # Graph context: current step, next steps, resources.
        if graph_context['current_step']:
            context_sections.append(f"""
CURRENT PROCESS STEP:
Title: {graph_context['current_step']['title']}
Description: {graph_context['current_step']['description']}
Properties: {graph_context['current_step'].get('properties', {})}
""")

        if graph_context['next_steps']:
            next_steps_text = "\n".join(
                f"- {step['title']}: {step['description']}"
                for step in graph_context['next_steps']
            )
            context_sections.append(f"""
NEXT STEPS:
{next_steps_text}
""")

        if graph_context['resources']:
            resources_text = "\n".join(
                f"- {res['title']} ({res['type']}): {res['properties'].get('url', res['properties'].get('phone', 'Contact available'))}"
                for res in graph_context['resources']
            )
            context_sections.append(f"""
RELEVANT RESOURCES:
{resources_text}
""")

        # Vector context: retrieved document bodies.
        if vector_context:
            vector_text = "\n\n".join(doc['content'] for doc in vector_context)
            context_sections.append(f"""
ADDITIONAL CONTEXT:
{vector_text}
""")

        # Join outside the f-string (the original used chr(10) to dodge the
        # pre-3.12 ban on backslashes inside f-string expressions).
        context_block = "\n".join(context_sections)
        full_prompt = f"""
{system_prompt}
USER QUERY: "{user_query}"
CLASSIFICATION: Process: {classification.get('process_name', 'General')}, Intent: {classification.get('intent', 'information')}
{context_block}
Please provide a helpful, structured response that:
1. Acknowledges the user's situation empathetically
2. Provides specific next steps if this is a process guidance request
3. Includes relevant official links and contact information
4. Reminds the user to consult a lawyer for specific legal advice
5. Uses bullet points, bold formatting, and clear structure
Format your response with clear sections and actionable information.
"""

        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:
            # Deliberate best-effort fallback: never surface a stack trace
            # to the end user.
            return f"I apologize, but I'm having trouble generating a response right now. Please try again or contact NALSA directly at nalsa-dla@nic.in for legal aid queries. Error: {str(e)}"

    def process_query(self, user_query: str) -> Dict[str, Any]:
        """Main pipeline: process a user query end-to-end.

        Returns:
            Dict with 'response' (LLM answer), 'classification',
            'graph_context', 'vector_context', and 'debug_info'.
        """
        # Step 1: classify the query (process, step, intent).
        classification = self.classifier.classify_query(user_query)

        # Step 2: retrieve knowledge-graph context.
        graph_context = self.retrieve_graph_context(classification)

        # Step 3: retrieve vector-store context.
        vector_context = self.retrieve_vector_context(user_query, classification)

        # Step 4: generate the final response.
        response = self.generate_response(user_query, graph_context, vector_context, classification)

        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name')
            }
        }
# Smoke-test the complete pipeline end-to-end from the command line.
if __name__ == "__main__":
    pipeline = ProcessAwareRAG()
    sample_query = "I need free legal help but I'm not sure if I qualify. My monthly income is around 45000 rupees."
    outcome = pipeline.process_query(sample_query)

    print("=== QUERY ===")
    print(sample_query)
    print("\n=== RESPONSE ===")
    print(outcome['response'])
    print("\n=== DEBUG INFO ===")
    debug = outcome['debug_info']
    print(f"Process: {debug['process_identified']}")
    print(f"Graph nodes: {debug['graph_nodes_found']}")
    print(f"Vector docs: {debug['vector_docs_found']}")