Rivalcoder committed on
Commit
074a13b
·
verified ·
1 Parent(s): 1a1eddd

Update process_aware_rag.py

Browse files
Files changed (1) hide show
  1. process_aware_rag.py +72 -131
process_aware_rag.py CHANGED
@@ -20,31 +20,28 @@ class ProcessAwareRAG:
20
  self.legal_graph = LegalProcessGraph()
21
  self.legal_graph.load_graph('legal_processes.pkl')
22
 
23
- # Initialize vector store (use writable path by default)
24
  chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
25
  os.makedirs(chroma_path, exist_ok=True)
26
 
27
- # Redirect model caches to writable directories
28
  default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
29
  os.environ.setdefault('HOME', '/data')
30
  os.makedirs(default_cache_root, exist_ok=True)
31
- os.makedirs(os.path.join(os.environ['HOME'], '.cache'), exist_ok=True)
32
- os.makedirs(os.path.join(os.environ['HOME'], '.cache', 'chroma'), exist_ok=True)
33
- os.environ.setdefault('HF_HOME', os.path.join(default_cache_root, 'hf'))
34
- os.environ.setdefault('TRANSFORMERS_CACHE', os.path.join(default_cache_root, 'transformers'))
35
- os.environ.setdefault('SENTENCE_TRANSFORMERS_HOME', os.path.join(default_cache_root, 'sentence-transformers'))
36
- os.environ.setdefault('XDG_CACHE_HOME', default_cache_root)
37
- for env_key in ['HF_HOME', 'TRANSFORMERS_CACHE', 'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME']:
38
  os.makedirs(os.environ[env_key], exist_ok=True)
39
 
40
- # Disable Chroma anonymized telemetry and initialize client
41
  client = chromadb.PersistentClient(
42
  path=chroma_path,
43
  settings=Settings(anonymized_telemetry=False)
44
  )
45
-
46
  # Ensure collection exists
47
- # Use explicit embedding function to ensure queries can compute embeddings
48
  embedding_function = embedding_functions.DefaultEmbeddingFunction()
49
  try:
50
  self.vector_collection = client.get_collection(
@@ -56,57 +53,40 @@ class ProcessAwareRAG:
56
  "legal_context",
57
  embedding_function=embedding_function
58
  )
59
-
60
- # Initialize LLM
61
  genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
62
  self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')
63
-
64
  def retrieve_graph_context(self, classification: Dict) -> Dict:
65
- """Retrieve relevant context from knowledge graph"""
66
  graph_context = {
67
  'current_step': None,
68
  'next_steps': [],
69
  'resources': [],
70
  'process_overview': None
71
  }
72
-
73
  if classification['process'] == 'general':
74
  return graph_context
75
-
76
  process_name = classification['process_name']
77
-
78
- # Find current step or process start
79
- if classification['step']:
80
- current_step_id = classification['step']
81
- else:
82
- current_step_id = self.legal_graph.find_process_start(process_name)
83
-
84
- if current_step_id:
85
- # Get current step info
86
- if current_step_id in self.legal_graph.graph.nodes:
87
- graph_context['current_step'] = {
88
- 'id': current_step_id,
89
- **self.legal_graph.graph.nodes[current_step_id]
90
- }
91
-
92
- # Get next steps
93
- graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
94
-
95
- # Get relevant resources
96
- graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
97
-
98
  return graph_context
99
-
100
  def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
101
- """Retrieve relevant context from vector store"""
102
-
103
- # Query vector store
104
  results = self.vector_collection.query(
105
  query_texts=[user_query],
106
  n_results=3,
107
  where={"process": classification.get('process_name', '')} if classification.get('process_name') else None
108
  )
109
-
110
  vector_context = []
111
  if results['documents']:
112
  for i in range(len(results['documents'][0])):
@@ -115,109 +95,74 @@ class ProcessAwareRAG:
115
  'metadata': results['metadatas'][0][i],
116
  'distance': results['distances'][0][i] if 'distances' in results else None
117
  })
118
-
119
  return vector_context
120
-
121
- def generate_response(self, user_query: str, graph_context: Dict, vector_context: List[Dict], classification: Dict) -> str:
122
- """Generate comprehensive response using LLM"""
123
-
124
- # Build structured prompt
 
 
 
 
 
125
  system_prompt = """
126
- You are a helpful, empathetic, and precise legal guide for Indian legal processes.
127
-
128
- IMPORTANT GUIDELINES:
129
- - You must NEVER give legal advice, only guide users through official processes
130
- - Always stress that users should consult a qualified lawyer for their specific case
131
- - Provide specific, actionable steps when possible
132
- - Include official links, phone numbers, and government portals when available
133
- - Be empathetic and understanding of the user's situation
134
- - Use clear, simple language avoiding legal jargon
135
- - Give In The User Language What User Is Communicating or Giving Questions SO Answer In That Language Accordingly
136
- - And Must be Short and Clear Not To give Like long Long Para answers Answer must be Short and Clear
137
- - Dont Give Large Answer Give Answer In One or two three Lines answers Give Accordingly
138
  """
139
-
140
- # Prepare context information
141
  context_sections = []
142
-
143
- # Add graph context
144
  if graph_context['current_step']:
145
  context_sections.append(f"""
146
- CURRENT PROCESS STEP:
147
- Title: {graph_context['current_step']['title']}
148
- Description: {graph_context['current_step']['description']}
149
- Properties: {graph_context['current_step'].get('properties', {})}
150
  """)
151
-
152
  if graph_context['next_steps']:
153
- next_steps_text = "\n".join([
154
- f"- {step['title']}: {step['description']}"
155
- for step in graph_context['next_steps']
156
- ])
157
- context_sections.append(f"""
158
- NEXT STEPS:
159
- {next_steps_text}
160
- """)
161
-
162
  if graph_context['resources']:
163
- resources_text = "\n".join([
164
- f"- {res['title']} ({res['type']}): {res['properties'].get('url', res['properties'].get('phone', 'Contact available'))}"
165
- for res in graph_context['resources']
166
- ])
167
- context_sections.append(f"""
168
- RELEVANT RESOURCES:
169
- {resources_text}
170
- """)
171
-
172
- # Add vector context
173
  if vector_context:
174
  vector_text = "\n\n".join([doc['content'] for doc in vector_context])
175
- context_sections.append(f"""
176
- ADDITIONAL CONTEXT:
177
- {vector_text}
178
- """)
179
-
180
- # Build final prompt
181
  full_prompt = f"""
182
  {system_prompt}
183
-
184
  USER QUERY: "{user_query}"
185
-
186
- CLASSIFICATION: Process: {classification.get('process_name', 'General')}, Intent: {classification.get('intent', 'information')}
187
-
188
  {chr(10).join(context_sections)}
189
-
190
- Please provide a helpful, structured response that:
191
- 1. Acknowledges the user's situation empathetically
192
- 2. Provides specific next steps if this is a process guidance request
193
- 3. Includes relevant official links and contact information
194
- 4. Reminds the user to consult a lawyer for specific legal advice
195
- 5. Uses bullet points, bold formatting, and clear structure
196
-
197
- Format your response with clear sections and actionable information.
198
  """
199
-
200
  try:
201
  response = self.llm.generate_content(full_prompt)
202
  return response.text
203
  except Exception as e:
204
- return f"I apologize, but I'm having trouble generating a response right now. Please try again or contact NALSA directly at nalsa-dla@nic.in for legal aid queries. Error: {str(e)}"
205
-
206
  def process_query(self, user_query: str) -> Dict[str, Any]:
207
- """Main pipeline: process user query end-to-end"""
208
-
209
- # Step 1: Classify query
210
  classification = self.classifier.classify_query(user_query)
211
-
212
- # Step 2: Retrieve graph context
213
  graph_context = self.retrieve_graph_context(classification)
214
-
215
- # Step 3: Retrieve vector context
216
  vector_context = self.retrieve_vector_context(user_query, classification)
217
-
218
- # Step 4: Generate response
219
  response = self.generate_response(user_query, graph_context, vector_context, classification)
220
-
221
  return {
222
  'response': response,
223
  'classification': classification,
@@ -230,19 +175,15 @@ class ProcessAwareRAG:
230
  }
231
  }
232
 
233
- # Test the complete pipeline
234
  if __name__ == "__main__":
235
  rag_system = ProcessAwareRAG()
236
-
237
- test_query = "I need free legal help but I'm not sure if I qualify. My monthly income is around 45000 rupees."
238
-
239
  result = rag_system.process_query(test_query)
240
-
241
  print("=== QUERY ===")
242
  print(test_query)
243
- print("\n=== RESPONSE ===")
244
  print(result['response'])
245
  print("\n=== DEBUG INFO ===")
246
- print(f"Process: {result['debug_info']['process_identified']}")
247
- print(f"Graph nodes: {result['debug_info']['graph_nodes_found']}")
248
- print(f"Vector docs: {result['debug_info']['vector_docs_found']}")
 
20
  self.legal_graph = LegalProcessGraph()
21
  self.legal_graph.load_graph('legal_processes.pkl')
22
 
23
+ # Initialize vector store
24
  chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
25
  os.makedirs(chroma_path, exist_ok=True)
26
 
27
+ # Redirect model caches
28
  default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
29
  os.environ.setdefault('HOME', '/data')
30
  os.makedirs(default_cache_root, exist_ok=True)
31
+ for env_key in [
32
+ 'HF_HOME', 'TRANSFORMERS_CACHE',
33
+ 'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME'
34
+ ]:
35
+ os.environ.setdefault(env_key, os.path.join(default_cache_root, env_key.lower()))
 
 
36
  os.makedirs(os.environ[env_key], exist_ok=True)
37
 
38
+ # Disable Chroma telemetry & init client
39
  client = chromadb.PersistentClient(
40
  path=chroma_path,
41
  settings=Settings(anonymized_telemetry=False)
42
  )
43
+
44
  # Ensure collection exists
 
45
  embedding_function = embedding_functions.DefaultEmbeddingFunction()
46
  try:
47
  self.vector_collection = client.get_collection(
 
53
  "legal_context",
54
  embedding_function=embedding_function
55
  )
56
+
57
+ # Init LLM
58
  genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
59
  self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')
60
+
61
  def retrieve_graph_context(self, classification: Dict) -> Dict:
 
62
  graph_context = {
63
  'current_step': None,
64
  'next_steps': [],
65
  'resources': [],
66
  'process_overview': None
67
  }
 
68
  if classification['process'] == 'general':
69
  return graph_context
70
+
71
  process_name = classification['process_name']
72
+ current_step_id = classification['step'] or self.legal_graph.find_process_start(process_name)
73
+
74
+ if current_step_id and current_step_id in self.legal_graph.graph.nodes:
75
+ graph_context['current_step'] = {
76
+ 'id': current_step_id,
77
+ **self.legal_graph.graph.nodes[current_step_id]
78
+ }
79
+ graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
80
+ graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
81
+
 
 
 
 
 
 
 
 
 
 
 
82
  return graph_context
83
+
84
  def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
 
 
 
85
  results = self.vector_collection.query(
86
  query_texts=[user_query],
87
  n_results=3,
88
  where={"process": classification.get('process_name', '')} if classification.get('process_name') else None
89
  )
 
90
  vector_context = []
91
  if results['documents']:
92
  for i in range(len(results['documents'][0])):
 
95
  'metadata': results['metadatas'][0][i],
96
  'distance': results['distances'][0][i] if 'distances' in results else None
97
  })
 
98
  return vector_context
99
+
100
+ def generate_response(self, user_query: str, graph_context: Dict,
101
+ vector_context: List[Dict], classification: Dict) -> str:
102
+ """
103
+ Generate Indian Law Assistant responses:
104
+ - Normal legal Q&A: Short, clear answers
105
+ - Complex / case-specific: Say consult professional
106
+ - If vector data is wrong: Override with correct info
107
+ """
108
+
109
  system_prompt = """
110
+ You are an **Indian Law Assistant**.
111
+
112
+ RULES:
113
+ - For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).
114
+ - Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.
115
+ - If query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.
116
+ - If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.
117
+ - Be empathetic, clear, and concise.
118
+ - Use the same language as the user (Hindi/English mix if needed).
119
+ - Avoid long paragraphs. Use short, crisp answers with bullet points where possible.
 
 
120
  """
121
+
122
+ # Build context
123
  context_sections = []
 
 
124
  if graph_context['current_step']:
125
  context_sections.append(f"""
126
+ CURRENT STEP:
127
+ {graph_context['current_step']['title']} - {graph_context['current_step']['description']}
 
 
128
  """)
 
129
  if graph_context['next_steps']:
130
+ context_sections.append("NEXT STEPS:\n" + "\n".join([
131
+ f"- {s['title']}: {s['description']}" for s in graph_context['next_steps']
132
+ ]))
 
 
 
 
 
 
133
  if graph_context['resources']:
134
+ context_sections.append("RESOURCES:\n" + "\n".join([
135
+ f"- {r['title']} ({r['type']}): {r['properties'].get('url', r['properties'].get('phone', 'Contact available'))}"
136
+ for r in graph_context['resources']
137
+ ]))
 
 
 
 
 
 
138
  if vector_context:
139
  vector_text = "\n\n".join([doc['content'] for doc in vector_context])
140
+ context_sections.append("REFERENCE CONTEXT:\n" + vector_text)
141
+
 
 
 
 
142
  full_prompt = f"""
143
  {system_prompt}
144
+
145
  USER QUERY: "{user_query}"
146
+ CLASSIFICATION: Process: {classification.get('process_name', 'General')} | Intent: {classification.get('intent', 'information')}
147
+
148
+ CONTEXT:
149
  {chr(10).join(context_sections)}
150
+
151
+ Generate the best possible short, accurate, and user-friendly response.
 
 
 
 
 
 
 
152
  """
153
+
154
  try:
155
  response = self.llm.generate_content(full_prompt)
156
  return response.text
157
  except Exception as e:
158
+ return f"⚠️ System error. Please try again later or contact NALSA (nalsa-dla@nic.in). Error: {str(e)}"
159
+
160
  def process_query(self, user_query: str) -> Dict[str, Any]:
 
 
 
161
  classification = self.classifier.classify_query(user_query)
 
 
162
  graph_context = self.retrieve_graph_context(classification)
 
 
163
  vector_context = self.retrieve_vector_context(user_query, classification)
 
 
164
  response = self.generate_response(user_query, graph_context, vector_context, classification)
165
+
166
  return {
167
  'response': response,
168
  'classification': classification,
 
175
  }
176
  }
177
 
178
+
179
  if __name__ == "__main__":
180
  rag_system = ProcessAwareRAG()
181
+ test_query = "Under which IPC section is cheating punishable in India?"
 
 
182
  result = rag_system.process_query(test_query)
183
+
184
  print("=== QUERY ===")
185
  print(test_query)
186
+ print("\n=== RESPONSE ===")
187
  print(result['response'])
188
  print("\n=== DEBUG INFO ===")
189
+ print(result['debug_info'])