import os
# Disable telemetry before importing chromadb
os.environ.setdefault("POSTHOG_DISABLED", "true")
os.environ.setdefault("CHROMA_TELEMETRY_DISABLED", "true")
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import google.generativeai as genai
from query_classifier import QueryClassifier
from graph_builder import LegalProcessGraph
from typing import Dict, List, Any
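
# Environment variables read by this module:
#   GOOGLE_API_KEY   - Gemini API key (required)
#   CHROMA_DB_PATH   - persistent Chroma directory (default: /tmp/legal_vector_db)
#   CACHE_ROOT       - root for model/embedding caches (default: /data/cache)
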
class ProcessAwareRAG:
    def __init__(self):
        # Initialize components
        self.classifier = QueryClassifier()
        self.legal_graph = LegalProcessGraph()
        self.legal_graph.load_graph('legal_processes.pkl')

        # Initialize vector store
        chroma_path = os.getenv('CHROMA_DB_PATH', '/tmp/legal_vector_db')
        os.makedirs(chroma_path, exist_ok=True)

        # Redirect model caches to a writable location
        default_cache_root = os.getenv('CACHE_ROOT', '/data/cache')
        os.environ.setdefault('HOME', '/data')
        os.makedirs(default_cache_root, exist_ok=True)
        for env_key in [
            'HF_HOME', 'TRANSFORMERS_CACHE',
            'SENTENCE_TRANSFORMERS_HOME', 'XDG_CACHE_HOME'
        ]:
            os.environ.setdefault(env_key, os.path.join(default_cache_root, env_key.lower()))
            os.makedirs(os.environ[env_key], exist_ok=True)

        # Disable Chroma telemetry & init client
        client = chromadb.PersistentClient(
            path=chroma_path,
            settings=Settings(anonymized_telemetry=False)
        )

        # Ensure the collection exists (created on first run, reused afterwards)
        embedding_function = embedding_functions.DefaultEmbeddingFunction()
        self.vector_collection = client.get_or_create_collection(
            "legal_context",
            embedding_function=embedding_function
        )

        # Init LLM
        genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))
        self.llm = genai.GenerativeModel('gemini-2.5-flash-lite')
    def retrieve_graph_context(self, classification: Dict) -> Dict:
        """Pull the current step, next steps, and resources for the classified process from the legal process graph."""
        graph_context = {
            'current_step': None,
            'next_steps': [],
            'resources': [],
            'process_overview': None
        }
        if classification['process'] == 'general':
            return graph_context

        process_name = classification['process_name']
        current_step_id = classification['step'] or self.legal_graph.find_process_start(process_name)
        if current_step_id and current_step_id in self.legal_graph.graph.nodes:
            graph_context['current_step'] = {
                'id': current_step_id,
                **self.legal_graph.graph.nodes[current_step_id]
            }
            graph_context['next_steps'] = self.legal_graph.get_next_steps(current_step_id)
            graph_context['resources'] = self.legal_graph.get_required_resources(current_step_id)
        return graph_context
    def retrieve_vector_context(self, user_query: str, classification: Dict) -> List[Dict]:
        """Retrieve the top matching documents from the vector store, filtered by process when one was identified."""
        # Only apply a metadata filter when a process was identified
        where_filter = None
        if classification.get('process_name'):
            where_filter = {"process": classification['process_name']}

        results = self.vector_collection.query(
            query_texts=[user_query],
            n_results=3,
            where=where_filter
        )

        vector_context = []
        # Guard against an empty result set ([[]]) before indexing into it
        if results['documents'] and results['documents'][0]:
            for i in range(len(results['documents'][0])):
                vector_context.append({
                    'content': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    'distance': results['distances'][0][i] if results.get('distances') else None
                })
        return vector_context
    def generate_response(self, user_query: str, history: List[Dict],
                          graph_context: Dict, vector_context: List[Dict],
                          classification: Dict) -> str:
        """
        Generate Indian Law Assistant responses:
        - Normal legal Q&A: short, clear answers
        - Complex / case-specific: suggest consulting a professional
        - If vector data is wrong: override with correct info
        """
        system_prompt = """
You are an **Indian Law Assistant**.

RULES:
- For general questions on Indian law (IPC, CrPC, FIR, consumer law, bail, etc.), answer briefly (1–3 lines).
- Use **sections of Indian laws (IPC, CrPC, Evidence Act, etc.)** where relevant.
- If the query is very **complex / case-dependent**, politely suggest consulting a qualified lawyer, but still share general process info.
- If retrieved data looks irrelevant or wrong, override it with correct general legal knowledge of Indian laws.
- Be empathetic, clear, and concise.
- Use the same language as the user (Hindi/English mix if needed).
- Avoid long paragraphs. Use short, crisp answers with bullet points where possible.
"""

        # Format chat history (use the last 5 messages for context)
        history_text = "\n".join(f"{m['role']}: {m['content']}" for m in history[-5:]) if history else ""

        # Build context sections from graph and vector retrieval
        context_sections = []
        if graph_context['current_step']:
            context_sections.append(
                f"CURRENT STEP:\n"
                f"{graph_context['current_step']['title']} - {graph_context['current_step']['description']}"
            )
        if graph_context['next_steps']:
            context_sections.append("NEXT STEPS:\n" + "\n".join(
                f"- {s['title']}: {s['description']}" for s in graph_context['next_steps']
            ))
        if graph_context['resources']:
            context_sections.append("RESOURCES:\n" + "\n".join(
                f"- {r['title']} ({r['type']}): {r['properties'].get('url', r['properties'].get('phone', 'Contact available'))}"
                for r in graph_context['resources']
            ))
        if vector_context:
            vector_text = "\n\n".join(doc['content'] for doc in vector_context)
            context_sections.append("REFERENCE CONTEXT:\n" + vector_text)

        # chr(10) is "\n" (backslashes are not allowed inside f-string expressions)
        full_prompt = f"""
{system_prompt}
CHAT HISTORY:
{history_text}

USER QUERY: "{user_query}"
CLASSIFICATION: Process: {classification.get('process_name', 'General')} | Intent: {classification.get('intent', 'information')}

CONTEXT:
{chr(10).join(context_sections)}

Generate the best possible short, accurate, and user-friendly response.
"""
        try:
            response = self.llm.generate_content(full_prompt)
            return response.text
        except Exception as e:
            return f"⚠️ System error. Please try again later or contact NALSA (nalsa-dla@nic.in). Error: {str(e)}"
    def process_query(self, user_query: str, history: List[Dict]) -> Dict[str, Any]:
        """Classify the query, gather graph and vector context, and generate a response."""
        classification = self.classifier.classify_query(user_query)
        graph_context = self.retrieve_graph_context(classification)
        vector_context = self.retrieve_vector_context(user_query, classification)
        response = self.generate_response(user_query, history, graph_context, vector_context, classification)
        return {
            'response': response,
            'classification': classification,
            'graph_context': graph_context,
            'vector_context': vector_context,
            'debug_info': {
                'graph_nodes_found': len(graph_context.get('next_steps', [])),
                'vector_docs_found': len(vector_context),
                'process_identified': classification.get('process_name'),
                'history_length': len(history)
            }
        }
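
# Illustrative call with prior chat history (message shape assumed from how
# generate_response formats history entries):
#   history = [{'role': 'user', 'content': 'What is an FIR?'},
#              {'role': 'assistant', 'content': 'An FIR is ...'}]
#   result = ProcessAwareRAG().process_query("How do I file one?", history)
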
if __name__ == "__main__":
    rag_system = ProcessAwareRAG()
    test_query = "Under which IPC section is cheating punishable in India?"
    result = rag_system.process_query(test_query, history=[])

    print("=== QUERY ===")
    print(test_query)
    print("\n=== RESPONSE ===")
    print(result['response'])
    print("\n=== DEBUG INFO ===")
    print(result['debug_info'])