yamraj047 committed on
Commit
37039ca
·
verified ·
1 Parent(s): 349be52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -49
app.py CHANGED
@@ -10,21 +10,19 @@ from groq import Groq
10
  import os
11
  from typing import List, Dict, Optional
12
  import logging
 
13
 
14
- # Configure logging
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
- # Initialize FastAPI app
19
  app = FastAPI(
20
  title="LexNepal AI API",
21
- description="Advanced Legal Intelligence API for Nepal Legal Code using RAG",
22
  version="1.0.0",
23
  docs_url="/",
24
  redoc_url="/redoc"
25
  )
26
 
27
- # CORS middleware
28
  app.add_middleware(
29
  CORSMiddleware,
30
  allow_origins=["*"],
@@ -33,7 +31,6 @@ app.add_middleware(
33
  allow_headers=["*"],
34
  )
35
 
36
- # Pydantic models
37
  class QueryRequest(BaseModel):
38
  query: str
39
  max_sources: Optional[int] = 10
@@ -62,61 +59,94 @@ class StatsResponse(BaseModel):
62
  class HealthResponse(BaseModel):
63
  status: str
64
  models_loaded: bool
 
65
 
66
- # Global variables - lazy loading
67
  _bi_encoder = None
68
  _cross_encoder = None
69
  _groq_client = None
70
  _index = None
71
  _metadata = None
72
- _embeddings = None
73
 
74
  def get_bi_encoder():
75
- """Lazy load bi-encoder"""
76
  global _bi_encoder
77
  if _bi_encoder is None:
78
- logger.info("Loading bi-encoder...")
79
  _bi_encoder = SentenceTransformer("all-mpnet-base-v2")
 
80
  return _bi_encoder
81
 
82
  def get_cross_encoder():
83
- """Lazy load cross-encoder"""
84
  global _cross_encoder
85
  if _cross_encoder is None:
86
  logger.info("Loading cross-encoder...")
87
  _cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
 
88
  return _cross_encoder
89
 
90
  def get_groq_client():
91
- """Lazy load Groq client"""
92
  global _groq_client
93
  if _groq_client is None:
94
  logger.info("Initializing Groq client...")
95
  groq_api_key = os.getenv("GROQ_API_KEY", "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")
96
- _groq_client = Groq(api_key=groq_api_key)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  return _groq_client
98
 
99
  def get_index():
100
- """Lazy load FAISS index"""
101
- global _index, _embeddings
102
  if _index is None:
103
- logger.info("Loading embeddings and creating index...")
104
- _embeddings = np.load("final_legal_embeddings.npy")
105
- _index = faiss.IndexFlatL2(_embeddings.shape[1])
106
- _index.add(_embeddings.astype('float32'))
 
 
 
 
 
 
 
 
 
107
  return _index
108
 
109
  def get_metadata():
110
- """Lazy load metadata"""
111
  global _metadata
112
  if _metadata is None:
113
  logger.info("Loading metadata...")
114
- with open("final_legal_laws_metadata.json", "r", encoding="utf-8") as f:
115
- _metadata = json.load(f)
 
 
 
 
 
 
 
 
116
  return _metadata
117
 
118
  def get_premium_context(query: str, max_sources: int = 10) -> List[Dict]:
119
- """Hybrid retrieval with cross-encoder reranking"""
120
  try:
121
  bi_encoder = get_bi_encoder()
122
  cross_encoder = get_cross_encoder()
@@ -159,38 +189,54 @@ def get_premium_context(query: str, max_sources: int = 10) -> List[Dict]:
159
 
160
  candidates = sorted(candidates, key=lambda x: x['rel_score'], reverse=True)[:max_sources]
161
 
 
162
  return candidates
163
 
164
  except Exception as e:
165
  logger.error(f"Error in context retrieval: {str(e)}")
166
- return []
167
 
168
  @app.get("/health", response_model=HealthResponse)
169
  async def health_check():
170
  """Health check endpoint"""
 
 
 
 
 
 
 
 
171
  return {
172
- "status": "healthy",
173
- "models_loaded": True
 
174
  }
175
 
176
  @app.get("/stats", response_model=StatsResponse)
177
  async def get_statistics():
178
  """Get database statistics"""
179
- metadata = get_metadata()
180
- unique_laws = len(set(d.get('law', '') for d in metadata))
181
-
182
- return {
183
- "total_provisions": len(metadata),
184
- "total_laws": unique_laws,
185
- "vector_dimensions": 768,
186
- "embedding_model": "all-mpnet-base-v2",
187
- "reranking_model": "ms-marco-MiniLM-L-6-v2",
188
- "llm_model": "llama-3.3-70b-versatile"
189
- }
 
 
 
 
190
 
191
  @app.post("/query", response_model=QueryResponse)
192
  async def process_legal_query(request: QueryRequest):
193
  """Process legal query with RAG pipeline"""
 
 
194
  if not request.query.strip():
195
  raise HTTPException(status_code=400, detail="Query cannot be empty")
196
 
@@ -205,13 +251,13 @@ async def process_legal_query(request: QueryRequest):
205
 
206
  if not candidates:
207
  return {
208
- "answer": "No relevant legal provisions found in the database for your query.",
209
  "sources": [],
210
  "query": request.query,
211
  "total_candidates": 0
212
  }
213
 
214
- # Build context
215
  context_str = "\n\n".join([
216
  f"[{d['law']} Section {d['section']}]: {d['text']}"
217
  for d in candidates
@@ -222,17 +268,19 @@ async def process_legal_query(request: QueryRequest):
222
 
223
  OPERATIONAL MANDATE:
224
  1. Answer STRICTLY from provided legal text
225
- 2. If information is absent, state: "No specific provision found"
226
  3. Always cite exact Law name and Section number
227
  4. Use formal, authoritative legal language
228
  5. NEVER hallucinate or infer beyond provided text
229
  6. Maintain zero-tolerance policy for speculation
230
 
231
- Format: "According to [Law], Section [Number]..."
232
- """
233
 
234
- # Generate response
 
235
  groq_client = get_groq_client()
 
236
  response = groq_client.chat.completions.create(
237
  model="llama-3.3-70b-versatile",
238
  messages=[
@@ -257,6 +305,8 @@ Format: "According to [Law], Section [Number]..."
257
  for d in candidates
258
  ]
259
 
 
 
260
  return {
261
  "answer": answer,
262
  "sources": sources,
@@ -264,19 +314,27 @@ Format: "According to [Law], Section [Number]..."
264
  "total_candidates": len(candidates)
265
  }
266
 
 
 
267
  except Exception as e:
268
- logger.error(f"Error: {str(e)}")
269
- raise HTTPException(status_code=500, detail=str(e))
270
 
271
  @app.get("/")
272
  async def root():
273
- """Root endpoint - redirect to docs"""
274
  return {
275
- "message": "LexNepal AI API",
276
  "version": "1.0.0",
277
- "docs": "/docs",
278
- "health": "/health",
279
- "stats": "/stats"
 
 
 
 
 
 
280
  }
281
 
282
  if __name__ == "__main__":
 
10
  import os
11
  from typing import List, Dict, Optional
12
  import logging
13
+ import httpx
14
 
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
 
18
  app = FastAPI(
19
  title="LexNepal AI API",
20
+ description="Advanced Legal Intelligence API for Nepal Legal Code",
21
  version="1.0.0",
22
  docs_url="/",
23
  redoc_url="/redoc"
24
  )
25
 
 
26
  app.add_middleware(
27
  CORSMiddleware,
28
  allow_origins=["*"],
 
31
  allow_headers=["*"],
32
  )
33
 
 
34
  class QueryRequest(BaseModel):
35
  query: str
36
  max_sources: Optional[int] = 10
 
59
  class HealthResponse(BaseModel):
60
  status: str
61
  models_loaded: bool
62
+ message: Optional[str] = None
63
 
 
64
  _bi_encoder = None
65
  _cross_encoder = None
66
  _groq_client = None
67
  _index = None
68
  _metadata = None
 
69
 
70
def get_bi_encoder():
    """Lazily load and cache the MPNet bi-encoder used for dense retrieval."""
    global _bi_encoder
    if _bi_encoder is not None:
        return _bi_encoder
    logger.info("Loading bi-encoder (MPNet)...")
    _bi_encoder = SentenceTransformer("all-mpnet-base-v2")
    logger.info("✅ Bi-encoder loaded successfully")
    return _bi_encoder
77
 
78
def get_cross_encoder():
    """Lazily load and cache the MS-MARCO cross-encoder used for reranking."""
    global _cross_encoder
    if _cross_encoder is not None:
        return _cross_encoder
    logger.info("Loading cross-encoder...")
    _cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    logger.info("✅ Cross-encoder loaded successfully")
    return _cross_encoder
85
 
86
def get_groq_client():
    """Lazily initialize and cache the Groq LLM client.

    The API key is read exclusively from the GROQ_API_KEY environment
    variable. The previous revision shipped a hard-coded key as a fallback,
    which leaks a live credential in source control; that key must be
    treated as compromised and revoked.

    Returns:
        Groq: the cached client instance.

    Raises:
        HTTPException: 503 when GROQ_API_KEY is unset or the client cannot
            be constructed.
    """
    global _groq_client
    if _groq_client is None:
        logger.info("Initializing Groq client...")
        # SECURITY: never hard-code API keys in source; require the env var.
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            logger.error("❌ GROQ_API_KEY environment variable is not set")
            raise HTTPException(
                status_code=503,
                detail="GROQ_API_KEY is not configured on the server"
            )

        try:
            # Try standard initialization
            _groq_client = Groq(api_key=groq_api_key)
            logger.info("✅ Groq client initialized (standard)")
        except TypeError as e:
            # Some groq/httpx version combinations reject the default
            # constructor arguments; retry with an explicit HTTP client.
            logger.warning(f"Standard Groq init failed: {e}, trying with custom HTTP client...")
            try:
                # Fallback with custom HTTP client
                http_client = httpx.Client(timeout=60.0)
                _groq_client = Groq(
                    api_key=groq_api_key,
                    http_client=http_client
                )
                logger.info("✅ Groq client initialized (with custom HTTP client)")
            except Exception as e2:
                logger.error(f"❌ Failed to initialize Groq client: {e2}")
                raise HTTPException(
                    status_code=503,
                    detail=f"Failed to initialize Groq client: {str(e2)}"
                )

    return _groq_client
114
 
115
def get_index():
    """Lazily build and cache the FAISS L2 index over the saved embeddings."""
    global _index
    if _index is not None:
        return _index
    logger.info("Loading embeddings and creating FAISS index...")
    try:
        embeddings = np.load("final_legal_embeddings.npy")
    except FileNotFoundError:
        logger.error("❌ Embeddings file not found")
        raise HTTPException(
            status_code=503,
            detail="Embeddings file not found. Please upload final_legal_embeddings.npy"
        )
    logger.info(f"Embeddings shape: {embeddings.shape}")
    _index = faiss.IndexFlatL2(embeddings.shape[1])
    _index.add(embeddings.astype('float32'))
    logger.info(f"✅ FAISS index created with {embeddings.shape[0]} vectors")
    return _index
132
 
133
def get_metadata():
    """Lazily load and cache the legal-provision metadata from its JSON file."""
    global _metadata
    if _metadata is not None:
        return _metadata
    logger.info("Loading metadata...")
    try:
        with open("final_legal_laws_metadata.json", "r", encoding="utf-8") as f:
            _metadata = json.load(f)
    except FileNotFoundError:
        logger.error("❌ Metadata file not found")
        raise HTTPException(
            status_code=503,
            detail="Metadata file not found. Please upload final_legal_laws_metadata.json"
        )
    logger.info(f"✅ Loaded {len(_metadata)} legal provisions")
    return _metadata
148
 
149
  def get_premium_context(query: str, max_sources: int = 10) -> List[Dict]:
 
150
  try:
151
  bi_encoder = get_bi_encoder()
152
  cross_encoder = get_cross_encoder()
 
189
 
190
  candidates = sorted(candidates, key=lambda x: x['rel_score'], reverse=True)[:max_sources]
191
 
192
+ logger.info(f"Retrieved {len(candidates)} relevant candidates")
193
  return candidates
194
 
195
  except Exception as e:
196
  logger.error(f"Error in context retrieval: {str(e)}")
197
+ raise HTTPException(status_code=500, detail=f"Context retrieval error: {str(e)}")
198
 
199
  @app.get("/health", response_model=HealthResponse)
200
  async def health_check():
201
  """Health check endpoint"""
202
+ try:
203
+ metadata = get_metadata()
204
+ models_loaded = True
205
+ message = f"API is healthy. {len(metadata)} provisions loaded."
206
+ except Exception as e:
207
+ models_loaded = False
208
+ message = f"Error: {str(e)}"
209
+
210
  return {
211
+ "status": "healthy" if models_loaded else "unhealthy",
212
+ "models_loaded": models_loaded,
213
+ "message": message
214
  }
215
 
216
  @app.get("/stats", response_model=StatsResponse)
217
  async def get_statistics():
218
  """Get database statistics"""
219
+ try:
220
+ metadata = get_metadata()
221
+ unique_laws = len(set(d.get('law', '') for d in metadata))
222
+
223
+ return {
224
+ "total_provisions": len(metadata),
225
+ "total_laws": unique_laws,
226
+ "vector_dimensions": 768,
227
+ "embedding_model": "all-mpnet-base-v2",
228
+ "reranking_model": "ms-marco-MiniLM-L-6-v2",
229
+ "llm_model": "llama-3.3-70b-versatile"
230
+ }
231
+ except Exception as e:
232
+ logger.error(f"Error getting stats: {str(e)}")
233
+ raise HTTPException(status_code=503, detail=str(e))
234
 
235
  @app.post("/query", response_model=QueryResponse)
236
  async def process_legal_query(request: QueryRequest):
237
  """Process legal query with RAG pipeline"""
238
+
239
+ # Validation
240
  if not request.query.strip():
241
  raise HTTPException(status_code=400, detail="Query cannot be empty")
242
 
 
251
 
252
  if not candidates:
253
  return {
254
+ "answer": "No relevant legal provisions found in the database for your query. Please try rephrasing or consult a legal professional.",
255
  "sources": [],
256
  "query": request.query,
257
  "total_candidates": 0
258
  }
259
 
260
+ # Build context string
261
  context_str = "\n\n".join([
262
  f"[{d['law']} Section {d['section']}]: {d['text']}"
263
  for d in candidates
 
268
 
269
  OPERATIONAL MANDATE:
270
  1. Answer STRICTLY from provided legal text
271
+ 2. If information is absent, state: "No specific provision found in current database"
272
  3. Always cite exact Law name and Section number
273
  4. Use formal, authoritative legal language
274
  5. NEVER hallucinate or infer beyond provided text
275
  6. Maintain zero-tolerance policy for speculation
276
 
277
+ When citing, use format: "According to [Law Name], Section [Number]..."
278
+ Provide clear, structured answers with proper legal citations."""
279
 
280
+ # Generate response using Groq
281
+ logger.info("Generating LLM response...")
282
  groq_client = get_groq_client()
283
+
284
  response = groq_client.chat.completions.create(
285
  model="llama-3.3-70b-versatile",
286
  messages=[
 
305
  for d in candidates
306
  ]
307
 
308
+ logger.info(f"✅ Query processed successfully with {len(sources)} sources")
309
+
310
  return {
311
  "answer": answer,
312
  "sources": sources,
 
314
  "total_candidates": len(candidates)
315
  }
316
 
317
+ except HTTPException:
318
+ raise
319
  except Exception as e:
320
+ logger.error(f"Error processing query: {str(e)}")
321
+ raise HTTPException(status_code=500, detail=f"Query processing error: {str(e)}")
322
 
323
  @app.get("/")
324
  async def root():
325
+ """Root endpoint - API info"""
326
  return {
327
+ "message": "🇳🇵 LexNepal AI API is running",
328
  "version": "1.0.0",
329
+ "description": "Advanced Legal Intelligence for Nepal Legal Code",
330
+ "endpoints": {
331
+ "docs": "/ (Swagger UI)",
332
+ "health": "/health (GET)",
333
+ "stats": "/stats (GET)",
334
+ "query": "/query (POST)"
335
+ },
336
+ "technology": "RAG with Hybrid Retrieval + Cross-Encoder Reranking",
337
+ "support": "https://huggingface.co/spaces/yamraj047/lexnepal-api"
338
  }
339
 
340
  if __name__ == "__main__":