Kartik Narang commited on
Commit
3cfeab7
·
0 Parent(s):

first clean commit

Browse files
Files changed (5) hide show
  1. app.py +594 -0
  2. requirements.txt +33 -0
  3. simple/ner.py +159 -0
  4. simple/rag.py +593 -0
  5. simple/summarizer.py +187 -0
app.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import asyncio
3
+ import uuid
4
+ from datetime import datetime, timedelta
5
+ from typing import Dict, Any, List, Optional
6
+ import logging
7
+ from contextlib import asynccontextmanager
8
+
9
+ from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
10
+ from fastapi.middleware.cors import CORSMiddleware
11
+ from fastapi.responses import JSONResponse
12
+ import uvicorn
13
+
14
+ from motor.motor_asyncio import AsyncIOMotorClient
15
+ import pymongo
16
+ from pymongo import ASCENDING
17
+ import PyPDF2
18
+ import docx
19
+ import io
20
+ from PIL import Image
21
+ import pytesseract
22
+
23
+ # Import our models
24
+ from simple.rag import initialize_models, process_documents, create_embedding, chunk_text_hierarchical
25
+ from simple.ner import extract_legal_entities
26
+ from simple.summarizer import summarize_legal_document
27
+
28
+ # Configure logging
29
+ logging.basicConfig(level=logging.INFO)
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Global variables
33
+ mongodb_client: Optional[AsyncIOMotorClient] = None
34
+ db = None
35
+ cleanup_task = None
36
+
37
+ # Configuration
38
+ MONGODB_URI = os.getenv("MONGODB_URI", "mongodb+srv://username:password@cluster.mongodb.net/")
39
+ DATABASE_NAME = os.getenv("DATABASE_NAME", "legal_rag_system")
40
+ HF_MODEL_ID = os.getenv("HF_MODEL_ID", "sentence-transformers/all-MiniLM-L6-v2")
41
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY", None)
42
+ SESSION_EXPIRE_HOURS = int(os.getenv("SESSION_EXPIRE_HOURS", "24"))
43
+
44
+ # Supported file types
45
+ SUPPORTED_EXTENSIONS = {'.pdf', '.txt', '.docx', '.doc'}
46
+ MAX_FILE_SIZE = 50 * 1024 * 1024 # 50MB
47
+
48
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook: initialize services, serve, then tear down."""
    await startup_event()   # bring up MongoDB, ML models, background cleanup
    yield                   # application serves requests while suspended here
    await shutdown_event()  # release resources on normal shutdown
56
+
57
# FastAPI application instance; `lifespan` wires startup/shutdown handlers.
app = FastAPI(
    title="Legal Document Processor",
    description="Process legal documents with NER, summarization, and embeddings",
    version="1.0.0",
    lifespan=lifespan
)

# CORS middleware — currently wide open to any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Configure this properly for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
72
+
73
async def startup_event():
    """Initialize services on startup.

    Connects to MongoDB, creates TTL/lookup indexes, loads the ML models,
    and starts the hourly session-cleanup task. Any failure is logged and
    re-raised so the process fails fast instead of serving half-initialized.
    """
    global mongodb_client, db, cleanup_task

    try:
        logger.info("🚀 Starting up Legal Document Processor...")

        # Initialize MongoDB
        logger.info("📊 Connecting to MongoDB...")
        mongodb_client = AsyncIOMotorClient(MONGODB_URI)
        db = mongodb_client[DATABASE_NAME]

        # Test connection — motor connects lazily, so ping to surface errors now
        await mongodb_client.admin.command('ping')
        logger.info("✅ MongoDB connected successfully")

        # Create indexes
        await create_indexes()

        # Initialize ML models (embedding model + optional Groq client)
        logger.info("🤖 Loading ML models...")
        initialize_models(HF_MODEL_ID, GROQ_API_KEY)
        logger.info("✅ Models loaded successfully")

        # Start cleanup task (runs until cancelled at shutdown)
        cleanup_task = asyncio.create_task(periodic_cleanup())
        logger.info("🧹 Cleanup task started")

        logger.info("🎉 Startup completed successfully!")

    except Exception as e:
        logger.error(f"❌ Startup failed: {str(e)}")
        raise
106
+
107
async def shutdown_event():
    """Stop the background cleanup task and close the MongoDB connection."""
    global mongodb_client, cleanup_task

    logger.info("🛑 Shutting down...")

    task = cleanup_task
    if task:
        task.cancel()
        try:
            await task  # wait for the task to acknowledge cancellation
        except asyncio.CancelledError:
            pass  # expected: we cancelled it on purpose

    client = mongodb_client
    if client:
        client.close()

    logger.info("✅ Shutdown completed")
124
+
125
async def create_indexes():
    """Create MongoDB indexes for optimal performance.

    Every collection gets a TTL index on ``created_at`` (expireAfterSeconds)
    so documents are purged automatically after SESSION_EXPIRE_HOURS.
    Failures are logged but not raised — the app can run without indexes.
    """
    try:
        # Sessions collection indexes
        await db.sessions.create_index([("session_id", ASCENDING)], unique=True)
        await db.sessions.create_index([("created_at", ASCENDING)], expireAfterSeconds=SESSION_EXPIRE_HOURS * 3600)
        await db.sessions.create_index([("status", ASCENDING)])

        # Chunks collection indexes
        await db.chunks.create_index([("session_id", ASCENDING)])
        await db.chunks.create_index([("chunk_id", ASCENDING)])
        await db.chunks.create_index([("created_at", ASCENDING)], expireAfterSeconds=SESSION_EXPIRE_HOURS * 3600)

        # NER results collection indexes
        await db.ner_results.create_index([("session_id", ASCENDING)])
        await db.ner_results.create_index([("created_at", ASCENDING)], expireAfterSeconds=SESSION_EXPIRE_HOURS * 3600)

        # Summaries collection indexes
        await db.summaries.create_index([("session_id", ASCENDING)])
        await db.summaries.create_index([("created_at", ASCENDING)], expireAfterSeconds=SESSION_EXPIRE_HOURS * 3600)

        logger.info("📊 Database indexes created successfully")

    except Exception as e:
        logger.error(f"❌ Failed to create indexes: {str(e)}")
150
+
151
async def periodic_cleanup():
    """Hourly background loop that purges expired sessions until cancelled."""
    while True:
        try:
            await asyncio.sleep(3600)  # one-hour cadence
            await cleanup_expired_sessions()
        except asyncio.CancelledError:
            break  # shutdown requested: leave the loop quietly
        except Exception as exc:
            # Keep the loop alive across transient errors
            logger.error(f"❌ Cleanup task error: {str(exc)}")
161
+
162
async def cleanup_expired_sessions():
    """Clean up expired sessions from MongoDB.

    Deletes every session (and its chunks, NER results, and summaries)
    whose ``created_at`` is older than SESSION_EXPIRE_HOURS. This is a
    belt-and-braces sweep on top of the TTL indexes created at startup.
    Errors are logged, never raised, so the cleanup loop keeps running.
    """
    try:
        cutoff_time = datetime.utcnow() - timedelta(hours=SESSION_EXPIRE_HOURS)

        # Count expired sessions first so we only log when work was done
        expired_count = await db.sessions.count_documents({
            "created_at": {"$lt": cutoff_time}
        })

        if expired_count > 0:
            # Delete expired sessions and related data
            await db.sessions.delete_many({"created_at": {"$lt": cutoff_time}})
            await db.chunks.delete_many({"created_at": {"$lt": cutoff_time}})
            await db.ner_results.delete_many({"created_at": {"$lt": cutoff_time}})
            await db.summaries.delete_many({"created_at": {"$lt": cutoff_time}})

            logger.info(f"🧹 Cleaned up {expired_count} expired sessions")

    except Exception as e:
        logger.error(f"❌ Cleanup failed: {str(e)}")
183
+
184
def extract_text_from_file(file_content: bytes, filename: str) -> str:
    """Dispatch text extraction based on the file's extension.

    Args:
        file_content: Raw bytes of the uploaded file.
        filename: Original file name; only its extension is used.

    Returns:
        Extracted plain text.

    Raises:
        ValueError: for an unsupported extension.
        Exception: whatever the format-specific extractor raises.
    """
    file_ext = os.path.splitext(filename.lower())[1]

    try:
        if file_ext == '.pdf':
            return extract_text_from_pdf(file_content)
        elif file_ext == '.txt':
            return file_content.decode('utf-8', errors='ignore')
        elif file_ext in ['.docx', '.doc']:
            return extract_text_from_docx(file_content)
        else:
            raise ValueError(f"Unsupported file type: {file_ext}")
    except Exception as e:
        # Bug fix: the message previously printed the literal "(unknown)";
        # include the actual filename so failures are traceable in logs.
        logger.error(f"❌ Text extraction failed for {filename}: {str(e)}")
        raise
200
+
201
def extract_text_from_pdf(file_content: bytes) -> str:
    """Extract text from a PDF byte stream with PyPDF2 (OCR path is a stub)."""
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(file_content))
        # One newline-terminated string per page, concatenated
        text = "".join(page.extract_text() + "\n" for page in reader.pages)

        if not text.strip():
            # Likely a scanned PDF with no text layer; OCR is not wired up yet
            logger.info("📷 No text found in PDF, attempting OCR...")
            text = "OCR extraction not implemented yet"

        return text
    except Exception as e:
        logger.error(f"❌ PDF extraction failed: {str(e)}")
        raise
221
+
222
def extract_text_from_docx(file_content: bytes) -> str:
    """Extract text from a DOCX byte stream, one line per paragraph."""
    try:
        document = docx.Document(io.BytesIO(file_content))
        return "".join(paragraph.text + "\n" for paragraph in document.paragraphs)
    except Exception as e:
        logger.error(f"❌ DOCX extraction failed: {str(e)}")
        raise
236
+
237
async def process_document_pipeline(
    session_id: str,
    text: str,
    filename: str,
    background_tasks: BackgroundTasks
):
    """Run the full processing pipeline for one uploaded document.

    Steps, in order: mark the session "processing", run NER, summarize,
    chunk + embed, then mark the session "completed". On any failure the
    session is marked "failed" with the error message; nothing is raised
    (this runs as a FastAPI background task).

    NOTE(review): ``background_tasks`` is accepted but never used in this
    body — confirm with the caller before removing it from the signature.
    """
    try:
        logger.info(f"🔄 Starting processing pipeline for session {session_id}")

        # Update session status so /status reflects progress immediately
        await db.sessions.update_one(
            {"session_id": session_id},
            {"$set": {"status": "processing", "updated_at": datetime.utcnow()}}
        )

        # Step 1: NER Processing
        logger.info(f"🔍 Running NER for session {session_id}")
        ner_results = extract_legal_entities(text)

        # Store NER results
        await db.ner_results.insert_one({
            "session_id": session_id,
            "filename": filename,
            "results": ner_results,
            "created_at": datetime.utcnow()
        })

        # Step 2: Summarization
        logger.info(f"📄 Running summarization for session {session_id}")
        summary_results = summarize_legal_document(
            text,
            max_sentences=5,
            groq_api_key=GROQ_API_KEY
        )

        # Store summary results
        await db.summaries.insert_one({
            "session_id": session_id,
            "filename": filename,
            "results": summary_results,
            "created_at": datetime.utcnow()
        })

        # Step 3: Chunking and Embedding
        logger.info(f"🧩 Creating chunks and embeddings for session {session_id}")
        chunks = chunk_text_hierarchical(text, filename)

        # Create embeddings and store chunks
        chunks_to_store = []
        for chunk in chunks:
            # Create embedding (one model call per chunk)
            embedding = create_embedding(chunk['text'])

            chunk_doc = {
                "session_id": session_id,
                "chunk_id": chunk['id'],
                "text": chunk['text'],
                "title": chunk['title'],
                "section_type": chunk['section_type'],
                "importance_score": chunk['importance_score'],
                "entities": chunk['entities'],
                "embedding": embedding.tolist(),  # Convert numpy array to list
                "created_at": datetime.utcnow()
            }
            chunks_to_store.append(chunk_doc)

        # Batch insert chunks (insert_many rejects an empty list, hence the guard)
        if chunks_to_store:
            await db.chunks.insert_many(chunks_to_store)

        # Update session as completed
        await db.sessions.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    "status": "completed",
                    "updated_at": datetime.utcnow(),
                    "chunk_count": len(chunks_to_store),
                    "processing_completed_at": datetime.utcnow()
                }
            }
        )

        logger.info(f"✅ Processing completed for session {session_id}")

    except Exception as e:
        logger.error(f"❌ Processing failed for session {session_id}: {str(e)}")

        # Update session with error so clients polling /status see the failure
        await db.sessions.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    "status": "failed",
                    "error": str(e),
                    "updated_at": datetime.utcnow()
                }
            }
        )
337
+
338
@app.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...)
):
    """Upload a legal document and kick off background processing.

    Validates extension and size, extracts text synchronously, records a
    session document, then schedules ``process_document_pipeline`` as a
    background task. Returns the new ``session_id`` for status polling.

    Raises:
        HTTPException 400: missing file, unsupported type, oversize file,
            or no extractable text.
        HTTPException 500: any other failure.
    """
    try:
        # Validate file
        if not file.filename:
            raise HTTPException(status_code=400, detail="No file provided")

        file_ext = os.path.splitext(file.filename.lower())[1]
        if file_ext not in SUPPORTED_EXTENSIONS:
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Supported: {', '.join(SUPPORTED_EXTENSIONS)}"
            )

        # Check file size — the whole upload is read into memory here
        file_content = await file.read()
        if len(file_content) > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
            )

        # Generate session ID
        session_id = str(uuid.uuid4())

        # Extract text (synchronous; happens before the request returns)
        logger.info(f"📄 Extracting text from {file.filename}")
        text = extract_text_from_file(file_content, file.filename)

        if not text.strip():
            raise HTTPException(status_code=400, detail="No text could be extracted from the file")

        # Create session record
        session_doc = {
            "session_id": session_id,
            "filename": file.filename,
            "file_size": len(file_content),
            "text_length": len(text),
            "word_count": len(text.split()),
            "status": "uploaded",
            "created_at": datetime.utcnow(),
            "updated_at": datetime.utcnow()
        }

        await db.sessions.insert_one(session_doc)

        # Start background processing (runs after the response is sent)
        background_tasks.add_task(
            process_document_pipeline,
            session_id,
            text,
            file.filename,
            background_tasks
        )

        logger.info(f"✅ Document uploaded successfully. Session ID: {session_id}")

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session_id": session_id,
                "filename": file.filename,
                "file_size": len(file_content),
                "text_length": len(text),
                "word_count": len(text.split()),
                "status": "processing",
                "message": "Document uploaded successfully. Processing started."
            }
        )

    except HTTPException:
        # Preserve intentional 4xx responses raised above
        raise
    except Exception as e:
        logger.error(f"❌ Upload failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
418
+
419
@app.get("/status/{session_id}")
async def get_session_status(session_id: str):
    """Return the processing status of a session.

    For completed sessions, enriches the response with entity/summary/chunk
    counts pulled from the result collections.

    Raises:
        HTTPException 404: unknown session_id.
        HTTPException 500: any other failure.
    """
    try:
        session = await db.sessions.find_one({"session_id": session_id})

        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # Convert ObjectId to string for JSON serialization
        session["_id"] = str(session["_id"])

        # Bug fix: the session document carries datetime values
        # (created_at, updated_at, ...) which JSONResponse's json.dumps
        # cannot serialize — render them as ISO-8601 strings.
        for key, value in list(session.items()):
            if isinstance(value, datetime):
                session[key] = value.isoformat()

        # Add processing progress info
        if session["status"] == "completed":
            # Get additional info
            ner_result = await db.ner_results.find_one({"session_id": session_id})
            summary_result = await db.summaries.find_one({"session_id": session_id})
            chunk_count = await db.chunks.count_documents({"session_id": session_id})

            session["ner_entities"] = ner_result["results"]["total_entities"] if ner_result else 0
            session["summary_available"] = bool(summary_result)
            session["chunk_count"] = chunk_count

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session": session
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Status check failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
455
+
456
@app.get("/results/{session_id}")
async def get_processing_results(session_id: str):
    """Return NER, summary, and chunk-metadata results for a completed session.

    Responds 202 while processing is still in progress; 404 for an unknown
    session. Chunk text and embeddings are excluded to keep the payload small.
    """
    try:
        # Check if session exists and is completed
        session = await db.sessions.find_one({"session_id": session_id})
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        if session["status"] != "completed":
            return JSONResponse(
                status_code=202,
                content={
                    "success": False,
                    "message": f"Processing not completed. Current status: {session['status']}"
                }
            )

        # Get NER results
        ner_result = await db.ner_results.find_one({"session_id": session_id})

        # Get summary results
        summary_result = await db.summaries.find_one({"session_id": session_id})

        # Get chunk metadata (not full text)
        chunks_cursor = db.chunks.find(
            {"session_id": session_id},
            {"text": 0, "embedding": 0}  # Exclude large fields
        )
        chunks_metadata = await chunks_cursor.to_list(length=None)

        # Bug fix: ObjectId and datetime values are not JSON serializable
        # by JSONResponse's json.dumps — stringify both per chunk.
        for chunk in chunks_metadata:
            chunk["_id"] = str(chunk["_id"])
            for key, value in list(chunk.items()):
                if isinstance(value, datetime):
                    chunk[key] = value.isoformat()

        # Same fix for the session's completion timestamp
        completed_at = session.get("processing_completed_at")
        if isinstance(completed_at, datetime):
            completed_at = completed_at.isoformat()

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session_id": session_id,
                "filename": session["filename"],
                "ner_results": ner_result["results"] if ner_result else None,
                "summary_results": summary_result["results"] if summary_result else None,
                "chunks_metadata": {
                    "total_chunks": len(chunks_metadata),
                    "chunks": chunks_metadata[:10]  # Return first 10 chunks metadata
                },
                "processing_completed_at": completed_at
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Results retrieval failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Results retrieval failed: {str(e)}")
512
+
513
@app.get("/health")
async def health_check():
    """Report service health; 503 with the error when MongoDB is unreachable."""
    try:
        # Probe the DB connection — the only dependency actually checked here
        await mongodb_client.admin.command('ping')
    except Exception as e:
        logger.error(f"❌ Health check failed: {str(e)}")
        return JSONResponse(
            status_code=503,
            content={
                "status": "unhealthy",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }
        )

    return JSONResponse(
        status_code=200,
        content={
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "services": {
                "mongodb": "connected",
                "ml_models": "loaded"
            }
        }
    )
541
+
542
@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Manually delete a session and all related data.

    Improvement: the original purged chunks/NER/summaries *before* checking
    whether the session existed, then raised 404 anyway. Delete the session
    first and fail fast on 404 so the related deletes only run when there
    actually was a session.
    """
    try:
        session_result = await db.sessions.delete_one({"session_id": session_id})
        if session_result.deleted_count == 0:
            raise HTTPException(status_code=404, detail="Session not found")

        # Remove all per-session derived data
        await db.chunks.delete_many({"session_id": session_id})
        await db.ner_results.delete_many({"session_id": session_id})
        await db.summaries.delete_many({"session_id": session_id})

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "message": f"Session {session_id} deleted successfully"
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"❌ Session deletion failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Session deletion failed: {str(e)}")
568
+
569
@app.get("/")
async def root():
    """Describe the API: service name, version, endpoints, accepted formats."""
    endpoints = {
        "upload": "POST /upload - Upload a legal document for processing",
        "status": "GET /status/{session_id} - Check processing status",
        "results": "GET /results/{session_id} - Get processing results",
        "health": "GET /health - Health check",
        "delete": "DELETE /session/{session_id} - Delete a session"
    }
    return {
        "service": "Legal Document Processor",
        "version": "1.0.0",
        "status": "running",
        "endpoints": endpoints,
        "supported_formats": list(SUPPORTED_EXTENSIONS)
    }
585
+
586
if __name__ == "__main__":
    # Default 7860 is the port Hugging Face Spaces exposes
    port = int(os.getenv("PORT", 7860))
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=port,
        reload=False,  # no auto-reload in production
        access_log=True
    )
requirements.txt ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Hugging Face Spaces requirements
2
+ gradio==4.44.0
3
+ requests==2.31.0
4
+ fastapi==0.115.6
5
+ uvicorn==0.32.1
6
+ python-multipart==0.0.9 # ✅ needed for FastAPI file uploads
7
+
8
+ # Core ML/NLP
9
+ torch==2.2.2
10
+ transformers==4.44.2
11
+ sentence-transformers==2.2.2
12
+ spacy==3.8.2
13
+ scikit-learn==1.5.2
14
+ numpy==1.26.4
15
+ pandas==2.2.3
16
+ nltk==3.9.1
17
+
18
+ # Retrieval / Search
19
+ faiss-cpu==1.7.4
20
+ rank-bm25==0.2.2
21
+
22
+ # File parsing (PDF, DOCX, OCR)
23
+ PyPDF2==3.0.1
24
+ pdfplumber==0.11.4
25
+ python-docx==1.1.2
26
+ pytesseract==0.3.13
27
+ easyocr==1.7.1
28
+ pdf2image==1.16.3
29
+ opencv-python==4.10.0.84
30
+ Pillow==10.4.0
31
+
32
+ # API clients
33
+ groq==0.13.0
simple/ner.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spacy
2
+ from huggingface_hub import snapshot_download
3
+ from typing import Dict, Any
4
+
5
def extract_legal_entities(text, model_id=None, hf_token=None) -> Dict[str, Any]:
    """
    Extract named entities from legal text.

    Args:
        text: Input text to process
        model_id: Optional Hugging Face model ID (defaults to en_core_web_sm)
        hf_token: Optional Hugging Face token

    Returns:
        Dictionary with entities and counts. On any failure the dict has an
        ``error`` key and empty entity fields instead of raising.
    """
    if not text or not text.strip():
        return {
            "error": "Empty text provided",
            "entities": [],
            "entity_counts": {},
            "total_entities": 0
        }

    # Load model (falls back to en_core_web_sm, may return None)
    nlp = _load_ner_model(model_id, hf_token)
    if not nlp:
        return {
            "error": "Failed to load NER model",
            "entities": [],
            "entity_counts": {},
            "total_entities": 0
        }

    try:
        # Process text (handle large texts by chunking past spaCy's comfort zone)
        if len(text) > 4000000:
            return _process_large_text(text, nlp)

        doc = nlp(text)

        entities = []
        entity_counts = {}

        for ent in doc.ents:
            # _process_entity may split "X and Y" spans into multiple tuples
            processed_entities = _process_entity(ent)

            for entity_text, entity_label in processed_entities:
                # NOTE(review): split sub-entities inherit the span's original
                # start/end offsets — confirm this is intended downstream.
                entity_info = {
                    "text": entity_text,
                    "label": entity_label,
                    "start": ent.start_char,
                    "end": ent.end_char
                }
                entities.append(entity_info)

                if entity_label not in entity_counts:
                    entity_counts[entity_label] = []
                entity_counts[entity_label].append(entity_text)

        # Process counts: deduplicate per label and record unique count
        for label in entity_counts:
            unique_entities = list(set(entity_counts[label]))
            entity_counts[label] = {
                "entities": unique_entities,
                "count": len(unique_entities)
            }

        return {
            "entities": entities,
            "entity_counts": entity_counts,
            "total_entities": len(entities),
            "unique_labels": list(entity_counts.keys())
        }

    except Exception as e:
        return {
            "error": str(e),
            "entities": [],
            "entity_counts": {},
            "total_entities": 0
        }
83
+
84
def _load_ner_model(model_id, hf_token):
    """Load a spaCy pipeline; fall back to en_core_web_sm, or None on failure."""
    name = model_id or 'en_core_web_sm'

    try:
        if name == 'en_core_web_sm':
            # Load standard model directly
            return spacy.load("en_core_web_sm")
        # Custom model: fetch the packaged pipeline from the Hugging Face Hub
        repo_path = snapshot_download(
            repo_id=name,
            token=hf_token if hf_token else None
        )
        return spacy.load(repo_path)
    except Exception:
        pass  # fall through to the standard-model fallback below

    # Last resort: the standard small English pipeline
    try:
        return spacy.load("en_core_web_sm")
    except Exception:
        return None
107
+
108
def _process_large_text(text, nlp, chunk_size=3000000):
    """Process very large text by running NER on fixed-size slices.

    Character offsets are rebased to the full text by adding each slice's
    start position (i * chunk_size). A slice that fails NER is skipped
    rather than aborting the whole document.
    """
    chunks = [text[i:i+chunk_size] for i in range(0, len(text), chunk_size)]
    all_entities = []
    all_entity_counts = {}

    for i, chunk in enumerate(chunks):
        try:
            doc = nlp(chunk)

            for ent in doc.ents:
                # May split conjoined "X and Y" spans into several tuples
                processed_entities = _process_entity(ent)

                for entity_text, entity_label in processed_entities:
                    entity_info = {
                        "text": entity_text,
                        "label": entity_label,
                        # Rebase slice-local offsets onto the full text
                        "start": ent.start_char + (i * chunk_size),
                        "end": ent.end_char + (i * chunk_size)
                    }
                    all_entities.append(entity_info)

                    if entity_label not in all_entity_counts:
                        all_entity_counts[entity_label] = []
                    all_entity_counts[entity_label].append(entity_text)

        except Exception:
            # Best-effort: skip a failing slice, keep the rest
            continue

    # Process counts: deduplicate per label and record unique count
    for label in all_entity_counts:
        unique_entities = list(set(all_entity_counts[label]))
        all_entity_counts[label] = {
            "entities": unique_entities,
            "count": len(unique_entities)
        }

    return {
        "entities": all_entities,
        "entity_counts": all_entity_counts,
        "total_entities": len(all_entities),
        "unique_labels": list(all_entity_counts.keys()),
        "processed_in_chunks": True,
        "num_chunks": len(chunks)
    }
153
+
154
+ def _process_entity(ent):
155
+ """Process individual entity (handle special cases like 'X and Y')"""
156
+ if ent.label_ in ["PRECEDENT", "ORG"] and " and " in ent.text:
157
+ parts = ent.text.split(" and ")
158
+ return [(p.strip(), "ORG") for p in parts]
159
+ return [(ent.text, ent.label_)]
simple/rag.py ADDED
@@ -0,0 +1,593 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from transformers import AutoTokenizer, AutoModel
4
+ from typing import List, Dict, Any, Tuple, Optional
5
+ import faiss
6
+ import hashlib
7
+ from tqdm import tqdm
8
+ from groq import Groq
9
+ import re
10
+ import nltk
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ import networkx as nx
13
+ from collections import defaultdict
14
+ import spacy
15
+ from rank_bm25 import BM25Okapi
16
+
17
# Global variables for models — populated by initialize_models()
MODEL = None        # transformer embedding model
TOKENIZER = None    # matching tokenizer
GROQ_CLIENT = None  # optional Groq LLM client (None when no API key)
NLP_MODEL = None    # spaCy pipeline, or None if unavailable
DEVICE = None       # torch device (cuda when available, else cpu)

# Global indices — NOTE(review): declared here but not populated in this
# module's visible code; confirm where they are built.
DENSE_INDEX = None
BM25_INDEX = None
CONCEPT_GRAPH = None
TOKEN_TO_CHUNKS = None
CHUNKS_DATA = []

# Legal knowledge base: concept category -> related terms used for
# query-concept matching and expansion
LEGAL_CONCEPTS = {
    'liability': ['negligence', 'strict liability', 'vicarious liability', 'product liability'],
    'contract': ['breach', 'consideration', 'offer', 'acceptance', 'damages', 'specific performance'],
    'criminal': ['mens rea', 'actus reus', 'intent', 'malice', 'premeditation'],
    'procedure': ['jurisdiction', 'standing', 'statute of limitations', 'res judicata'],
    'evidence': ['hearsay', 'relevance', 'privilege', 'burden of proof', 'admissibility'],
    'constitutional': ['due process', 'equal protection', 'free speech', 'search and seizure']
}

# Query-type classifier: type -> trigger substrings matched against the
# lowercased query (first family with a hit wins in analyze_query)
QUERY_PATTERNS = {
    'precedent': ['case', 'precedent', 'ruling', 'held', 'decision'],
    'statute_interpretation': ['statute', 'section', 'interpretation', 'meaning', 'definition'],
    'factual': ['what happened', 'facts', 'circumstances', 'events'],
    'procedure': ['how to', 'procedure', 'process', 'filing', 'requirements']
}
47
+
48
def initialize_models(model_id: str, groq_api_key: str = None):
    """Load the embedding model, optional Groq client, and spaCy pipeline.

    Args:
        model_id: Hugging Face model ID for the embedding model.
        groq_api_key: Optional key; when present a Groq client is created.

    Side effects: assigns the module globals MODEL, TOKENIZER, GROQ_CLIENT,
    NLP_MODEL, and DEVICE.
    """
    global MODEL, TOKENIZER, GROQ_CLIENT, NLP_MODEL, DEVICE

    # Best-effort NLTK data download (host may be offline or data cached).
    # Fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
    try:
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
    except Exception:
        pass

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {DEVICE}")

    print(f"Loading model: {model_id}")
    TOKENIZER = AutoTokenizer.from_pretrained(model_id)
    MODEL = AutoModel.from_pretrained(model_id).to(DEVICE)
    MODEL.eval()  # inference mode

    if groq_api_key:
        GROQ_CLIENT = Groq(api_key=groq_api_key)

    # spaCy is optional; entity extraction degrades to regex-only without it
    try:
        NLP_MODEL = spacy.load("en_core_web_sm")
    except Exception:
        print("SpaCy model not found, using basic NER")
        NLP_MODEL = None
74
+
75
def create_embedding(text: str) -> np.ndarray:
    """Embed *text* as an L2-normalized, attention-mask mean-pooled vector.

    Returns a 1-D numpy array (the single sequence's embedding).
    """
    encoded = TOKENIZER(text, padding=True, truncation=True,
                        max_length=512, return_tensors='pt').to(DEVICE)

    with torch.no_grad():
        model_out = MODEL(**encoded)
        mask = encoded['attention_mask']
        hidden = model_out.last_hidden_state
        # Mean-pool over real (non-padding) tokens only
        mask_f = mask.unsqueeze(-1).expand(hidden.size()).float()
        pooled = torch.sum(hidden * mask_f, 1) / torch.clamp(mask_f.sum(1), min=1e-9)

    # Unit-length vectors make dot products equal cosine similarity
    pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)

    return pooled.cpu().numpy()[0]
91
+
92
def extract_legal_entities(text: str) -> List[Dict[str, Any]]:
    """Collect spaCy named entities plus regex-matched citations/statute refs.

    Each result dict has 'text', 'type', and a relative 'importance' weight.
    """
    found = []

    # Named entities via spaCy, when a pipeline was loaded
    if NLP_MODEL:
        doc = NLP_MODEL(text[:5000])  # Limit for performance
        for ent in doc.ents:
            if ent.label_ in ['PERSON', 'ORG', 'LAW', 'GPE']:
                found.append({
                    'text': ent.text,
                    'type': ent.label_,
                    'importance': 1.0
                })

    # Legal citations (number–reporter–number shape)
    for match in re.finditer(r'\b\d+\s+[A-Z][a-z]+\.?\s+\d+\b', text):
        found.append({
            'text': match.group(),
            'type': 'case_citation',
            'importance': 2.0
        })

    # Statute references ("§ N" or "Section N")
    for match in re.finditer(r'§\s*\d+[\.\d]*|\bSection\s+\d+', text):
        found.append({
            'text': match.group(),
            'type': 'statute',
            'importance': 1.5
        })

    return found
125
+
126
def analyze_query(query: str) -> Dict[str, Any]:
    """Classify a query, pull entities/concepts, and build expanded variants."""
    lowered = query.lower()

    # Query-type classification: first pattern family with a hit wins
    query_type = next(
        (qtype for qtype, patterns in QUERY_PATTERNS.items()
         if any(pattern in lowered for pattern in patterns)),
        'general'
    )

    # Entities mentioned in the query itself
    entities = extract_legal_entities(query)

    # Concepts from the knowledge base mentioned verbatim in the query
    key_concepts = [
        concept
        for concepts in LEGAL_CONCEPTS.values()
        for concept in concepts
        if concept in lowered
    ]

    expanded_queries = [query]

    # Concept expansion: append up to three matched concepts
    if key_concepts:
        expanded_queries.append(f"{query} {' '.join(key_concepts[:3])}")

    # Type-based expansion
    if query_type == 'precedent':
        expanded_queries.append(f"legal precedent case law {query}")
    elif query_type == 'statute_interpretation':
        expanded_queries.append(f"statutory interpretation meaning {query}")

    # HyDE: add a hypothetical answer document when an LLM is available
    if GROQ_CLIENT:
        hyde_doc = generate_hypothetical_document(query)
        if hyde_doc:
            expanded_queries.append(hyde_doc)

    return {
        'original_query': query,
        'query_type': query_type,
        'entities': entities,
        'key_concepts': key_concepts,
        'expanded_queries': expanded_queries[:4]  # Limit to 4 queries
    }
173
+
174
def generate_hypothetical_document(query: str) -> Optional[str]:
    """Generate a short hypothetical answer document (HyDE technique).

    Asks the Groq LLM to write a brief, legal-sounding excerpt that would
    answer *query*; the text is embedded alongside the real query to
    improve dense retrieval recall.

    Returns:
        The generated excerpt, or None when no client is configured or
        the API call fails (HyDE is strictly best-effort).
    """
    if not GROQ_CLIENT:
        return None

    try:
        prompt = f"""Generate a brief hypothetical legal document excerpt that would answer this question: {query}

Write it as if it's from an actual legal case or statute. Be specific and use legal language.
Keep it under 100 words."""

        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a legal expert generating hypothetical legal text."},
                {"role": "user", "content": prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.3,
            max_tokens=150
        )
        return response.choices[0].message.content
    except Exception:
        # Bug fix: was a bare `except:`, which also swallowed SystemExit
        # and KeyboardInterrupt.  API failures are still tolerated.
        return None
198
+
199
def chunk_text_hierarchical(text: str, title: str = "") -> List[Dict[str, Any]]:
    """Split legal text into chunks annotated with section type and importance.

    Sentences are grouped into ~500-word chunks (max 10 sentences each).
    Section headings (FACTS, HOLDING, REASONING, DISSENT, CONCLUSION) are
    located first so each chunk can be tagged with the section it falls
    in; importance scores weight holdings/conclusions above facts and
    dissents and are boosted by the entities found inside the chunk.

    Args:
        text: Raw document text.
        title: Document title, embedded in each chunk and its id.

    Returns:
        List of chunk dicts: id, text, title, section_type,
        importance_score, entities, and an embedding placeholder
        (filled in later by build_all_indices).
    """
    chunks: List[Dict[str, Any]] = []

    # Collapse whitespace runs so sentence offsets are stable.
    text = re.sub(r'\s+', ' ', text)

    # Locate legal section headings and record their start offsets.
    section_patterns = [
        (r'(?i)\bFACTS?\b[:\s]', 'facts'),
        (r'(?i)\bHOLDING\b[:\s]', 'holding'),
        (r'(?i)\bREASONING\b[:\s]', 'reasoning'),
        (r'(?i)\bDISSENT\b[:\s]', 'dissent'),
        (r'(?i)\bCONCLUSION\b[:\s]', 'conclusion')
    ]

    sections = []
    for pattern, section_type in section_patterns:
        for match in re.finditer(pattern, text):
            sections.append((match.start(), section_type))
    sections.sort(key=lambda x: x[0])

    # Sentence segmentation (NLTK when available, naive split otherwise).
    import nltk
    try:
        sentences = nltk.sent_tokenize(text)
    except Exception:  # narrowed from a bare `except:`
        sentences = text.split('. ')

    current_section = 'introduction'
    section_sentences: List[str] = []
    chunk_size = 500  # target chunk length, in words

    # Bug fix: track a running search offset instead of `text.find(sent)`,
    # which always returned the FIRST occurrence and so mis-attributed the
    # section of any repeated sentence (and was O(n*m) besides).
    search_pos = 0
    for sent in sentences:
        sent_pos = text.find(sent, search_pos)
        if sent_pos == -1:
            # Tokenizer may normalize text; fall back to the current offset.
            sent_pos = search_pos
        else:
            search_pos = sent_pos + len(sent)

        # The last heading at or before this sentence wins.
        for pos, stype in sections:
            if sent_pos >= pos:
                current_section = stype

        section_sentences.append(sent)

        # Emit a chunk once it is long enough (by words or sentence count).
        chunk_text = ' '.join(section_sentences)
        if len(chunk_text.split()) >= chunk_size or len(section_sentences) >= 10:
            chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]

            # Weight the chunk by its section type...
            importance = 1.0
            section_weights = {
                'holding': 2.0, 'conclusion': 1.8, 'reasoning': 1.5,
                'facts': 1.2, 'dissent': 0.8
            }
            importance *= section_weights.get(current_section, 1.0)

            # ...and by the average importance of entities it mentions.
            entities = extract_legal_entities(chunk_text)
            if entities:
                entity_score = sum(e['importance'] for e in entities) / len(entities)
                importance *= (1 + entity_score * 0.5)

            chunks.append({
                'id': chunk_id,
                'text': chunk_text,
                'title': title,
                'section_type': current_section,
                'importance_score': importance,
                'entities': entities,
                'embedding': None  # filled in during indexing
            })

            section_sentences = []

    # Flush any trailing sentences into a final chunk.
    if section_sentences:
        chunk_text = ' '.join(section_sentences)
        chunk_id = hashlib.md5(f"{title}_{len(chunks)}_{chunk_text[:50]}".encode()).hexdigest()[:12]
        chunks.append({
            'id': chunk_id,
            'text': chunk_text,
            'title': title,
            'section_type': current_section,
            'importance_score': 1.0,
            'entities': extract_legal_entities(chunk_text),
            'embedding': None
        })

    return chunks
290
+
291
def build_all_indices(chunks: List[Dict[str, Any]]):
    """Populate every retrieval index from the given chunks.

    Mutates module globals: DENSE_INDEX (FAISS inner-product index over
    normalized embeddings), BM25_INDEX (sparse lexical index),
    TOKEN_TO_CHUNKS (token -> set of chunk indices), CONCEPT_GRAPH
    (chunks linked by shared entity mentions), and CHUNKS_DATA (the
    chunk list itself).  Each chunk's 'embedding' field is filled in
    as a side effect.
    """
    global DENSE_INDEX, BM25_INDEX, CONCEPT_GRAPH, TOKEN_TO_CHUNKS, CHUNKS_DATA

    CHUNKS_DATA = chunks
    print(f"Building indices for {len(chunks)} chunks...")

    # 1. Dense embeddings stored in a FAISS inner-product index
    #    (inner product == cosine similarity for normalized vectors).
    print("Building FAISS index...")
    vectors = []
    for chunk in tqdm(chunks, desc="Creating embeddings"):
        vec = create_embedding(chunk['text'])
        chunk['embedding'] = vec
        vectors.append(vec)

    matrix = np.vstack(vectors)
    DENSE_INDEX = faiss.IndexFlatIP(matrix.shape[1])
    DENSE_INDEX.add(matrix.astype('float32'))

    # 2. BM25 over whitespace-tokenized, lower-cased chunk text.
    print("Building BM25 index...")
    BM25_INDEX = BM25Okapi([c['text'].lower().split() for c in chunks])

    # 3. Inverted token index (ColBERT-style token-level matching).
    print("Building ColBERT token index...")
    TOKEN_TO_CHUNKS = defaultdict(set)
    for idx, chunk in enumerate(chunks):
        for token in chunk['text'].lower().split():
            TOKEN_TO_CHUNKS[token].add(idx)

    # 4. Graph of chunks connected by shared entity mentions.
    print("Building legal concept graph...")
    CONCEPT_GRAPH = nx.Graph()
    for idx, chunk in enumerate(chunks):
        CONCEPT_GRAPH.add_node(idx, text=chunk['text'][:200], importance=chunk['importance_score'])

        own_entities = set(e['text'] for e in chunk['entities'])
        for other_idx, other in enumerate(chunks[idx + 1:], idx + 1):
            shared = own_entities & set(e['text'] for e in other['entities'])
            if shared:
                CONCEPT_GRAPH.add_edge(idx, other_idx, weight=len(shared))

    print("All indices built successfully!")
339
+
340
def multi_stage_retrieval(query_analysis: Dict[str, Any], top_k: int = 10) -> List[Tuple[Dict[str, Any], float]]:
    """Retrieve and rank chunks by fusing four retrieval signals.

    Stages: (1) dense FAISS search over expanded queries, (2) BM25 sparse
    retrieval, (3) entity overlap with the query, (4) concept-graph
    neighbors of the best seeds.  Per-method scores are normalized,
    combined with fixed weights, then boosted by chunk importance and by
    matching the query type to the chunk's section.

    Args:
        query_analysis: Output of analyze_query().
        top_k: Number of results to return.

    Returns:
        Top-k (chunk, final_score) pairs, best first.
    """
    candidates = {}

    print("Performing multi-stage retrieval...")

    # Stage 1: Dense retrieval with expanded queries
    print("Stage 1: Dense retrieval...")
    for query in query_analysis['expanded_queries'][:3]:
        query_emb = create_embedding(query)
        scores, indices = DENSE_INDEX.search(
            query_emb.reshape(1, -1).astype('float32'),
            top_k * 2
        )

        for idx, score in zip(indices[0], scores[0]):
            if idx < len(CHUNKS_DATA):
                chunk_id = CHUNKS_DATA[idx]['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
                candidates[chunk_id]['scores']['dense'] = float(score)

    # Stage 2: Sparse retrieval (BM25)
    print("Stage 2: Sparse retrieval...")
    query_tokens = query_analysis['original_query'].lower().split()
    bm25_scores = BM25_INDEX.get_scores(query_tokens)
    top_bm25_indices = np.argsort(bm25_scores)[-top_k*2:][::-1]

    for idx in top_bm25_indices:
        if idx < len(CHUNKS_DATA):
            chunk_id = CHUNKS_DATA[idx]['id']
            if chunk_id not in candidates:
                candidates[chunk_id] = {'chunk': CHUNKS_DATA[idx], 'scores': {}}
            candidates[chunk_id]['scores']['bm25'] = float(bm25_scores[idx])

    # Stage 3: Entity-based retrieval (accumulates entity importance).
    print("Stage 3: Entity-based retrieval...")
    for entity in query_analysis['entities']:
        for chunk in CHUNKS_DATA:
            chunk_entity_texts = [e['text'].lower() for e in chunk['entities']]
            if entity['text'].lower() in chunk_entity_texts:
                chunk_id = chunk['id']
                if chunk_id not in candidates:
                    candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                candidates[chunk_id]['scores']['entity'] = \
                    candidates[chunk_id]['scores'].get('entity', 0) + entity['importance']

    # Stage 4: Graph-based retrieval from the first few candidate seeds.
    print("Stage 4: Graph-based retrieval...")
    if candidates and CONCEPT_GRAPH:
        # Perf fix: build an id -> index map once instead of scanning
        # CHUNKS_DATA for every seed candidate (was O(seeds * n)).
        id_to_idx = {chunk['id']: i for i, chunk in enumerate(CHUNKS_DATA)}
        seed_chunks = [id_to_idx[cid] for cid in list(candidates)[:5] if cid in id_to_idx]

        for seed_idx in seed_chunks:
            if seed_idx in CONCEPT_GRAPH:
                for neighbor_idx in list(CONCEPT_GRAPH.neighbors(seed_idx))[:3]:
                    if neighbor_idx < len(CHUNKS_DATA):
                        chunk = CHUNKS_DATA[neighbor_idx]
                        chunk_id = chunk['id']
                        if chunk_id not in candidates:
                            candidates[chunk_id] = {'chunk': chunk, 'scores': {}}
                        candidates[chunk_id]['scores']['graph'] = 0.5

    # Score fusion: weighted sum of per-method normalized scores.
    print("Combining scores...")
    weights = {'dense': 0.35, 'bm25': 0.25, 'entity': 0.25, 'graph': 0.15}
    final_scores = []

    for chunk_id, data in candidates.items():
        chunk = data['chunk']
        scores = data['scores']

        final_score = 0
        for method, weight in weights.items():
            if method in scores:
                # Map each raw score into [0, 1] before weighting.
                if method == 'dense':
                    normalized = (scores[method] + 1) / 2  # cosine [-1,1] -> [0,1]
                elif method == 'bm25':
                    normalized = min(scores[method] / 10, 1)
                elif method == 'entity':
                    normalized = min(scores[method] / 3, 1)
                else:
                    normalized = scores[method]

                final_score += weight * normalized

        # Boost by chunk importance and query-type/section affinity.
        final_score *= chunk['importance_score']

        if query_analysis['query_type'] == 'precedent' and chunk['section_type'] == 'holding':
            final_score *= 1.5
        elif query_analysis['query_type'] == 'factual' and chunk['section_type'] == 'facts':
            final_score *= 1.5

        final_scores.append((chunk, final_score))

    final_scores.sort(key=lambda x: x[1], reverse=True)
    return final_scores[:top_k]
445
+
446
def generate_answer_with_reasoning(query: str, retrieved_chunks: List[Tuple[Dict[str, Any], float]]) -> Dict[str, Any]:
    """Generate an IRAC-structured legal answer grounded in retrieved chunks.

    Builds a context block from the retrieved (chunk, score) pairs, asks
    the Groq LLM for an IRAC analysis restricted to that context, and
    returns the answer with a confidence estimate and source metadata.

    Returns:
        Dict with 'answer', 'confidence', 'sources' on success, or an
        'error' key when the client is missing, no chunks were supplied,
        or the API call fails.
    """
    if not GROQ_CLIENT:
        return {'error': 'Groq client not initialized'}

    # Robustness fix: the confidence average below divides by
    # min(3, len(retrieved_chunks)); guard against an empty list.
    if not retrieved_chunks:
        return {'error': 'No retrieved chunks provided', 'sources': []}

    # Assemble the grounding context shown to the model.
    context_parts = []
    for i, (chunk, score) in enumerate(retrieved_chunks, 1):
        entities = ', '.join([e['text'] for e in chunk['entities'][:3]])
        context_parts.append(f"""
Document {i} [{chunk['title']}] - Relevance: {score:.2f}
Section: {chunk['section_type']}
Key Entities: {entities}
Content: {chunk['text'][:800]}
""")

    context = "\n---\n".join(context_parts)

    system_prompt = """You are an expert legal analyst. Provide thorough legal analysis using the IRAC method:
1. ISSUE: Identify the legal issue(s)
2. RULE: State the applicable legal rules/precedents
3. APPLICATION: Apply the rules to the facts
4. CONCLUSION: Provide a clear conclusion

CRITICAL: Base ALL responses on the provided document excerpts only. Quote directly when making claims.
If information is not in the excerpts, state "This information is not provided in the available documents."
"""

    user_prompt = f"""Query: {query}

Retrieved Legal Documents:
{context}

Please provide a comprehensive legal analysis using IRAC method. Cite the documents when making claims."""

    try:
        response = GROQ_CLIENT.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.1,
            max_tokens=1000
        )

        answer = response.choices[0].message.content

        # Confidence: mean relevance of the top 3 chunks, capped at 100.
        avg_score = sum(score for _, score in retrieved_chunks[:3]) / min(3, len(retrieved_chunks))
        confidence = min(avg_score * 100, 100)

        return {
            'answer': answer,
            'confidence': confidence,
            'sources': [
                {
                    'chunk_id': chunk['id'],
                    'title': chunk['title'],
                    'section': chunk['section_type'],
                    'relevance_score': float(score),
                    'excerpt': chunk['text'][:200] + '...',
                    'entities': [e['text'] for e in chunk['entities'][:5]]
                }
                for chunk, score in retrieved_chunks
            ]
        }

    except Exception as e:
        return {
            'error': f'Error generating answer: {str(e)}',
            'sources': [{'chunk': c['text'][:200], 'score': s} for c, s in retrieved_chunks[:3]]
        }
519
+
520
+ # Main functions for external use
521
+ def process_documents(documents: List[Dict[str, str]]) -> Dict[str, Any]:
522
+ """Process documents and build indices"""
523
+ all_chunks = []
524
+
525
+ for doc in documents:
526
+ chunks = chunk_text_hierarchical(doc['text'], doc.get('title', 'Document'))
527
+ all_chunks.extend(chunks)
528
+
529
+ build_all_indices(all_chunks)
530
+
531
+ return {
532
+ 'success': True,
533
+ 'chunk_count': len(all_chunks),
534
+ 'message': f'Processed {len(documents)} documents into {len(all_chunks)} chunks'
535
+ }
536
+
537
def query_documents(query: str, top_k: int = 5) -> Dict[str, Any]:
    """Answer *query* against the indexed corpus.

    Runs query analysis, multi-stage retrieval, then LLM answer
    generation; the query analysis is attached to whatever dict is
    returned (answer or retrieval failure).
    """
    if not CHUNKS_DATA:
        return {'error': 'No documents indexed. Call process_documents first.'}

    analysis = analyze_query(query)
    hits = multi_stage_retrieval(analysis, top_k)

    if not hits:
        return {
            'error': 'No relevant documents found',
            'query_analysis': analysis
        }

    answer = generate_answer_with_reasoning(query, hits)
    answer['query_analysis'] = analysis
    return answer
559
+
560
def search_chunks_simple(query: str, top_k: int = 3) -> List[Dict[str, Any]]:
    """Lightweight search API kept for compatibility.

    Returns [{'chunk': {id, text, title}, 'score': float}, ...] for the
    top-ranked chunks, without invoking answer generation.
    """
    if not CHUNKS_DATA:
        return []

    hits = multi_stage_retrieval(analyze_query(query), top_k)

    return [
        {
            'chunk': {
                'id': chunk['id'],
                'text': chunk['text'],
                'title': chunk['title']
            },
            'score': score
        }
        for chunk, score in hits
    ]
580
+
581
def generate_conservative_answer(query: str, context_chunks: List[Dict[str, Any]]) -> str:
    """Compatibility wrapper: answer from pre-retrieved chunk dicts.

    Accepts the [{'chunk': ..., 'score': ...}] shape produced by
    search_chunks_simple and returns a plain answer string (or an
    error message on failure).
    """
    if not context_chunks:
        return "No relevant information found."

    # Re-shape into the (chunk, score) pairs the generator expects.
    pairs = [(item['chunk'], item['score']) for item in context_chunks]
    result = generate_answer_with_reasoning(query, pairs)

    if 'error' in result:
        return result['error']
    return result.get('answer', 'Unable to generate answer.')
simple/summarizer.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from groq import Groq
import re
from nltk.tokenize import sent_tokenize
import nltk

# Download required NLTK data (best-effort: offline environments simply
# skip the download).
try:
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)
except Exception:
    # Bug fix: was a bare `except:`, which also swallowed SystemExit and
    # KeyboardInterrupt at import time.
    pass
14
+
15
def summarize_legal_document(text, max_sentences=5, groq_api_key=None, model_path=None):
    """
    Summarize legal document text.

    Always produces an extractive summary; when a Groq API key is given,
    an abstractive LLM summary is attempted and used instead if it
    succeeds (any LLM failure silently falls back to the extractive one).

    Args:
        text: Input text to summarize
        max_sentences: Maximum number of sentences in summary (clamped to 3-20)
        groq_api_key: Optional Groq API key for enhanced summarization
        model_path: Optional custom model path

    Returns:
        Dictionary with summary, size metrics, and a success flag
    """
    if not text or not text.strip():
        return {"error": "Empty text provided", "success": False}

    # Clamp the requested length to a sane range.
    max_sentences = max(3, min(max_sentences, 20))

    result = {
        "original_length": len(text),
        "word_count": len(text.split()),
        "sentence_count": len(sent_tokenize(text)),
        "success": False
    }

    try:
        # Extractive summary is the guaranteed baseline.
        result["summary"] = _extractive_summarize(text, max_sentences)

        # Optionally upgrade to an LLM summary (best-effort).
        if groq_api_key:
            try:
                enhanced = _groq_summarize(text, max_sentences, groq_api_key)
                if enhanced:
                    result["summary"] = enhanced
            except Exception:
                pass

        # Final size metrics for whichever summary survived.
        summary_text = result.get("summary", "")
        result["summary_length"] = len(summary_text)
        result["compression_ratio"] = (
            result["summary_length"] / result["original_length"]
            if result["original_length"] > 0 else 0
        )
        result["success"] = True

    except Exception as exc:
        result["error"] = str(exc)
        result["success"] = False

    return result
69
+
70
def _extractive_summarize(text, max_sentences):
    """Score sentences with legal-document heuristics and keep the best.

    Scoring favors legal keywords, document position (opener, closer,
    first fifth), moderate sentence length, numbers/dates, and citation
    patterns.  Selected sentences are re-emitted in original order.
    Returns the text unchanged when it is already short enough.
    """
    sentences = sent_tokenize(text)

    if len(sentences) <= max_sentences:
        return text

    legal_keywords = [
        'court', 'judge', 'plaintiff', 'defendant', 'appellant', 'respondent',
        'held', 'ruled', 'decided', 'judgment', 'order', 'section', 'article',
        'provision', 'law', 'legal', 'case', 'appeal', 'petition', 'writ',
        'contract', 'agreement', 'liability', 'damages', 'evidence', 'witness',
        'statute', 'regulation', 'finding', 'conclusion', 'reasoning'
    ]

    scored = []
    total = len(sentences)

    for position, sentence in enumerate(sentences):
        if not sentence.strip():
            continue

        lowered = sentence.lower()

        # One point per legal keyword present.
        points = sum(1 for kw in legal_keywords if kw in lowered)

        # Positional bonus: opener, closer, then the first fifth.
        if position == 0:
            points += 3
        elif position == total - 1:
            points += 2
        elif position < total * 0.2:
            points += 1

        # Prefer moderately long sentences.
        n_words = len(sentence.split())
        if 15 <= n_words <= 40:
            points += 2
        elif 10 <= n_words <= 50:
            points += 1

        # Years, percentages, or dollar amounts.
        if re.search(r'\b\d{4}\b|\b\d+\s*(percent|%|\$)', sentence):
            points += 1

        # Case citations ("123 Cal. 456") or "v. Name" captions.
        if re.search(r'\d+\s+[A-Z][a-z]+\.?\s+\d+|\bv\.\s+[A-Z]', sentence):
            points += 2

        scored.append((points, position, sentence))

    # Keep the highest-scoring sentences, then restore document order.
    scored.sort(reverse=True, key=lambda item: item[0])
    chosen = scored[:max_sentences]
    chosen.sort(key=lambda item: item[1])

    return ' '.join(item[2] for item in chosen)
132
+
133
def _groq_summarize(text, max_sentences, api_key):
    """Abstractive summarization via the Groq LLM; returns None on any failure."""
    try:
        client = Groq(api_key=api_key)

        # Keep the prompt within a safe size for the model.
        if len(text) > 6000:
            text = text[:6000] + "\n[...text truncated...]"

        system_prompt = """You are an expert legal document summarizer. Create concise, accurate summaries that capture the most important information.

Guidelines:
1. Focus on key legal facts, holdings, and conclusions
2. Preserve important legal terminology and concepts
3. Maintain logical flow of legal reasoning
4. Include relevant case citations, statutes, or regulations
5. Be precise and avoid unnecessary elaboration"""

        user_prompt = f"""Please summarize the following legal document in approximately {max_sentences} sentences:

{text}

Provide a clear, concise summary:"""

        response = client.chat.completions.create(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            model="llama-3.1-8b-instant",
            temperature=0.2,
            max_tokens=800,
            top_p=0.9
        )

        candidate = response.choices[0].message.content.strip()
        # Reject trivially short responses so callers keep the extractive one.
        if candidate and len(candidate) > 20:
            return candidate

    except Exception:
        pass

    return None
176
+
177
+ def _chunk_text(text, max_words):
178
+ """Split text into chunks for processing"""
179
+ words = text.split()
180
+ chunks = []
181
+
182
+ for i in range(0, len(words), max_words):
183
+ chunk_words = words[i:i + max_words]
184
+ if chunk_words:
185
+ chunks.append(' '.join(chunk_words))
186
+
187
+ return chunks