Update process_aware_rag.py

process_aware_rag.py  CHANGED  +72 -131
@@ -20,31 +20,28 @@ class ProcessAwareRAG:
         self.legal_graph = LegalProcessGraph()
         self.legal_graph.load_graph('legal_processes.pkl')
 
         # Initialize vector store
         chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
         os.makedirs(chroma_path, exist_ok=True)
 
         # Redirect model caches
         default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
         os.environ.setdefault('HOME', '/data')
         os.makedirs(default_cache_root, exist_ok=True)
-        os.environ.setdefault('XDG_CACHE_HOME', default_cache_root)
-        for env_key in ['HF_HOME', 'TRANSFORMERS_CACHE', 'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME']:
+        for env_key in [
+            'HF_HOME', 'TRANSFORMERS_CACHE',
+            'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME'
+        ]:
+            os.environ.setdefault(env_key, os.path.join(default_cache_root, env_key.lower()))
             os.makedirs(os.environ[env_key], exist_ok=True)
 
-        # Disable Chroma
+        # Disable Chroma telemetry & init client
         client = chromadb.PersistentClient(
             path=chroma_path,
             settings=Settings(anonymized_telemetry=False)
         )
 
         # Ensure collection exists
-        # Use explicit embedding function to ensure queries can compute embeddings
         embedding_function = embedding_functions.DefaultEmbeddingFunction()
         try:
             self.vector_collection = client.get_collection(
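A note on the new cache loop: `os.environ.setdefault` only writes a variable when it is not already set, so values injected by the Space's configuration always win. Below is a minimal standalone sketch of the layout the loop produces, assuming `CACHE_ROOT` is unset so the `/data/cache` default applies; as a side note, recent `transformers` releases treat `TRANSFORMERS_CACHE` as deprecated in favour of `HF_HOME`, so setting both, as this commit does, is a harmless belt-and-braces choice.

    import os

    # Same loop as the commit, run in isolation; with CACHE_ROOT unset,
    # default_cache_root falls back to '/data/cache'.
    default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
    for env_key in ['HF_HOME', 'TRANSFORMERS_CACHE',
                    'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME']:
        os.environ.setdefault(env_key, os.path.join(default_cache_root, env_key.lower()))

    print(os.environ['HF_HOME'])                     # /data/cache/hf_home (unless already set)
    print(os.environ['SENTENCE_TRANSFORMERS_HOME'])  # /data/cache/sentence_transformers_home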
@@ -56,57 +53,40 @@ class ProcessAwareRAG:
                 "legal_context",
                 embedding_function=embedding_function
             )
 
-        #
+        # Init LLM
         genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
         self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')
 
     def retrieve_graph_context(self, classification: Dict) -> Dict:
-        """Retrieve relevant context from knowledge graph"""
         graph_context = {
             'current_step': None,
             'next_steps': [],
             'resources': [],
             'process_overview': None
         }
         if classification['process'] == 'general':
             return graph_context
 
         process_name = classification['process_name']
-        graph_context['current_step'] = {
-            'id': current_step_id,
-            **self.legal_graph.graph.nodes[current_step_id]
-        }
-
-        # Get next steps
-        graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
-
-        # Get relevant resources
-        graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
+        current_step_id = classification['step'] or self.legal_graph.find_process_start(process_name)
+
+        if current_step_id and current_step_id in self.legal_graph.graph.nodes:
+            graph_context['current_step'] = {
+                'id': current_step_id,
+                **self.legal_graph.graph.nodes[current_step_id]
+            }
+            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
+            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
 
         return graph_context
 
     def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
-        """Retrieve relevant context from vector store"""
-
-        # Query vector store
         results = self.vector_collection.query(
             query_texts=[user_query],
             n_results=3,
             where={"process": classification.get('process_name', '')} if classification.get('process_name') else None
         )
 
         vector_context = []
         if results['documents']:
             for i in range(len(results['documents'][0])):
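The reworked `retrieve_graph_context` leans on `LegalProcessGraph.find_process_start`, which is defined outside this diff. For readers following along, a plausible shape for it, assuming the wrapped graph is a `networkx.DiGraph` (the diff's use of `self.legal_graph.graph.nodes` suggests this) whose nodes carry a `process` attribute, might look like the hypothetical sketch below; the real helper may differ.

    from typing import Optional
    import networkx as nx

    def find_process_start(graph: nx.DiGraph, process_name: str) -> Optional[str]:
        """Hypothetical helper: return the step of `process_name` that has no
        predecessor within the same process (i.e. the process entry point)."""
        steps = [n for n, d in graph.nodes(data=True) if d.get('process') == process_name]
        for node in steps:
            in_process_preds = [p for p in graph.predecessors(node)
                                if graph.nodes[p].get('process') == process_name]
            if not in_process_preds:
                return node
        return steps[0] if steps else None

    if __name__ == "__main__":
        g = nx.DiGraph()
        g.add_node("aid_check", process="legal_aid", title="Check eligibility")
        g.add_node("aid_apply", process="legal_aid", title="Apply at DLSA")
        g.add_edge("aid_check", "aid_apply")
        print(find_process_start(g, "legal_aid"))  # -> aid_check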
@@ -115,109 +95,74 @@ class ProcessAwareRAG:
                     'metadata': results['metadatas'][0][i],
                     'distance': results['distances'][0][i] if 'distances' in results else None
                 })
 
         return vector_context
 
-    def generate_response(self, user_query: str, graph_context: Dict,
+    def generate_response(self, user_query: str, graph_context: Dict,
+                          vector_context: List[Dict], classification: Dict) -> str:
+        """
+        Generate Indian Law Assistant responses:
+        - Normal legal Q&A: Short, clear answers
+        - Complex / case-specific: Say consult professional
+        - If vector data is wrong: Override with correct info
+        """
 
         system_prompt = """
-        - And Must be Short and Clear Not To give Like long Long Para answers Answer must be Short and Clear
-        - Dont Give Large Answer Give Answer In One or two three Lines answers Give Accordingly
+        You are an **Indian Law Assistant**.
+
+        RULES:
+        - For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).
+        - Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.
+        - If query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.
+        - If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.
+        - Be empathetic, clear, and concise.
+        - Use the same language as the user (Hindi/English mix if needed).
+        - Avoid long paragraphs. Use short, crisp answers with bullet points where possible.
         """
 
-        #
+        # Build context
         context_sections = []
 
-        # Add graph context
         if graph_context['current_step']:
             context_sections.append(f"""
-            Description: {graph_context['current_step']['description']}
-            Properties: {graph_context['current_step'].get('properties', {})}
+            CURRENT STEP:
+            {graph_context['current_step']['title']} - {graph_context['current_step']['description']}
             """)
 
         if graph_context['next_steps']:
-            context_sections.append(f"""
-            NEXT STEPS:
-            {next_steps_text}
-            """)
+            context_sections.append("NEXT STEPS:\n" + "\n".join([
+                f"- {s['title']}: {s['description']}" for s in graph_context['next_steps']
+            ]))
 
         if graph_context['resources']:
-            context_sections.append(f"""
-            RELEVANT RESOURCES:
-            {resources_text}
-            """)
+            context_sections.append("RESOURCES:\n" + "\n".join([
+                f"- {r['title']} ({r['type']}): {r['properties'].get('url', r['properties'].get('phone', 'Contact available'))}"
+                for r in graph_context['resources']
+            ]))
 
-        # Add vector context
         if vector_context:
             vector_text = "\n\n".join([doc['content'] for doc in vector_context])
-            context_sections.append(f"""
-            {vector_text}
-            """)
+            context_sections.append("REFERENCE CONTEXT:\n" + vector_text)
 
-        # Build final prompt
         full_prompt = f"""
         {system_prompt}
 
         USER QUERY: "{user_query}"
+        CLASSIFICATION: Process: {classification.get('process_name', 'General')} | Intent: {classification.get('intent', 'information')}
 
+        CONTEXT:
         {chr(10).join(context_sections)}
 
-        1. Acknowledges the user's situation empathetically
-        2. Provides specific next steps if this is a process guidance request
-        3. Includes relevant official links and contact information
-        4. Reminds the user to consult a lawyer for specific legal advice
-        5. Uses bullet points, bold formatting, and clear structure
-
-        Format your response with clear sections and actionable information.
+        Generate the best possible short, accurate, and user-friendly response.
         """
 
         try:
             response = self.llm.generate_content(full_prompt)
             return response.text
         except Exception as e:
+            return f"⚠️ System error. Please try again later or contact NALSA (nalsa-dla@nic.in). Error: {str(e)}"
 
     def process_query(self, user_query: str) -> Dict[str, Any]:
-        """Main pipeline: process user query end-to-end"""
-
-        # Step 1: Classify query
         classification = self.classifier.classify_query(user_query)
-
-        # Step 2: Retrieve graph context
         graph_context = self.retrieve_graph_context(classification)
-
-        # Step 3: Retrieve vector context
         vector_context = self.retrieve_vector_context(user_query, classification)
-
-        # Step 4: Generate response
         response = self.generate_response(user_query, graph_context, vector_context, classification)
 
         return {
             'response': response,
             'classification': classification,
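One detail worth calling out in the prompt assembly above: `chr(10)` is simply the newline character. It is spelled that way because, before Python 3.12, f-string expressions may not contain backslashes, so writing `"\n".join(context_sections)` directly inside the `full_prompt` f-string would raise a SyntaxError. A quick demonstration:

    sections = ["CURRENT STEP:\n...", "NEXT STEPS:\n- ..."]

    # chr(10) == "\n"; the chr() spelling sidesteps the pre-3.12 ban on
    # backslashes inside f-string expressions.
    joined = f"CONTEXT:\n{chr(10).join(sections)}"
    assert joined == "CONTEXT:\n" + "\n".join(sections)
    print(joined)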
@@ -230,19 +175,15 @@ class ProcessAwareRAG:
             }
         }
 
 if __name__ == "__main__":
     rag_system = ProcessAwareRAG()
-    test_query = "I need free legal help but I'm not sure if I qualify. My monthly income is around 45000 rupees."
+    test_query = "Under which IPC section is cheating punishable in India?"
     result = rag_system.process_query(test_query)
 
     print("=== QUERY ===")
     print(test_query)
     print("\n=== RESPONSE ===")
     print(result['response'])
     print("\n=== DEBUG INFO ===")
-    print(f"Graph nodes: {result['debug_info']['graph_nodes_found']}")
-    print(f"Vector docs: {result['debug_info']['vector_docs_found']}")
+    print(result['debug_info'])
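Finally, for anyone exercising the Space after this commit, a minimal smoke test of the public surface. This sketch assumes `GOOGLE_API_KEY` is set, `legal_processes.pkl` is present, and the classifier wired up earlier in `__init__` (outside this diff) initializes cleanly.

    from process_aware_rag import ProcessAwareRAG

    rag = ProcessAwareRAG()
    result = rag.process_query("How do I apply for free legal aid?")

    # Keys per the return dict in this diff.
    print(result['response'])        # LLM answer text (or the fallback error string)
    print(result['classification'])  # classifier output used for routing
    print(result['debug_info'])      # retrieval stats, as printed by __main__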