Spaces:
Sleeping
Sleeping
| import os | |
| import asyncio | |
| import uuid | |
| from datetime import datetime, timedelta | |
| from typing import Dict, Any, List, Optional | |
| import logging | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| import uvicorn | |
| from motor.motor_asyncio import AsyncIOMotorClient | |
| import pymongo | |
| from pymongo import ASCENDING | |
| import PyPDF2 | |
| import docx | |
| import io | |
| from PIL import Image | |
| import pytesseract | |
| # Import our models | |
| from simple.rag import initialize_models, process_documents, create_embedding, chunk_text_hierarchical | |
| from simple.ner import process_text as run_ner | |
| from simple.summarizer import summarize_legal_document | |
# --- Logging ---------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Process-wide handles (assigned by startup_event) ------------------------
mongodb_client: "Optional[AsyncIOMotorClient]" = None
db = None
cleanup_task = None

# --- Configuration -----------------------------------------------------------
# Values read from the environment fall back to the defaults shown here.
MONGODB_URI = os.getenv(
    "MONGODB_URI",
    "mongodb+srv://username:password@cluster.mongodb.net/",
)
DATABASE_NAME = os.getenv("DATABASE_NAME", "legal_rag_system")

# The embedding model is deliberately hard-coded rather than configurable.
HF_MODEL_ID = "sentence-transformers/all-MiniLM-L6-v2"

# Unset GROQ_API_KEY yields None (os.getenv's default).
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
SESSION_EXPIRE_HOURS = int(os.getenv("SESSION_EXPIRE_HOURS", "24"))

# Optional Hugging Face token (needed only if the NER model is private);
# HUGGINGFACE_TOKEN wins over HF_TOKEN when both are set.
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN") or os.getenv("HF_TOKEN")

# --- Upload constraints ------------------------------------------------------
SUPPORTED_EXTENSIONS = {".pdf", ".txt", ".docx", ".doc"}
MAX_FILE_SIZE = 50 * 1024 * 1024  # 50MB
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan manager.

    Runs ``startup_event`` before the application starts serving and
    ``shutdown_event`` once it stops.

    Args:
        app: The FastAPI application this lifespan is attached to (unused
            here; required by the lifespan protocol).
    """
    # Startup. If this raises, nothing was yielded and shutdown is skipped.
    await startup_event()
    try:
        yield
    finally:
        # Shutdown must run even when the serving phase ends with an error,
        # otherwise the cleanup task and MongoDB client are never released.
        await shutdown_event()
app = FastAPI(
    title="Legal Document Processor",
    description="Process legal documents with NER, summarization, and embeddings",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS middleware.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable and default to "*" (any origin), matching the previous behaviour.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins,
    # Credentialed cross-origin requests are only enabled for an explicit
    # origin list; combining them with a wildcard origin would let any site
    # issue credentialed requests against this API.
    allow_credentials="*" not in _cors_allow_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
async def startup_event():
    """Initialize services on startup.

    Connects to MongoDB, creates indexes, loads the embedding model, warms up
    the NER model (best-effort) and starts the periodic cleanup task. The
    module-level ``mongodb_client``, ``db`` and ``cleanup_task`` are assigned
    here.

    Raises:
        Exception: Any failure other than the NER warm-up is logged and
            re-raised so the application does not start half-initialized.
    """
    global mongodb_client, db, cleanup_task
    try:
        logger.info("π Starting up Legal Document Processor...")

        # Initialize MongoDB and verify the connection with a ping.
        logger.info("π Connecting to MongoDB...")
        mongodb_client = AsyncIOMotorClient(MONGODB_URI)
        db = mongodb_client[DATABASE_NAME]
        await mongodb_client.admin.command('ping')
        logger.info("β MongoDB connected successfully")

        # Create indexes
        await create_indexes()

        # Initialize ML models (embeddings / retrieval backbone)
        logger.info(f"π€ Loading embedding model for RAG: {HF_MODEL_ID}")
        initialize_models(HF_MODEL_ID, GROQ_API_KEY)
        logger.info(f"β Embedding model loaded: {HF_MODEL_ID}")

        # Surface NER token presence (actual NER loads lazily in simple.ner)
        if HUGGINGFACE_TOKEN:
            os.environ["HUGGINGFACE_TOKEN"] = HUGGINGFACE_TOKEN
            logger.info("π HUGGINGFACE_TOKEN detected for NER model access")
        else:
            logger.info("βΉοΈ No HUGGINGFACE_TOKEN provided (NER model assumed public)")

        # Eagerly load and validate the NER model once. A failure here is
        # logged but does not abort startup.
        try:
            ner_model_id = "kn29/my-ner-model"
            logger.info(f"π§ Preloading NER model: {ner_model_id}")
            warmup = run_ner("Warmup NER model load.", model_id=ner_model_id)
            # run_ner reports failures through an "error" key in its result
            # (the processing pipeline checks the same key), so a returned
            # error must not be logged as a successful preload.
            if isinstance(warmup, dict) and warmup.get("error"):
                raise RuntimeError(warmup["error"])
            logger.info(f"β NER model ready: {ner_model_id}")
        except Exception as e:
            logger.error(f"β NER preload failed: {str(e)}")

        # Start cleanup task
        cleanup_task = asyncio.create_task(periodic_cleanup())
        logger.info("π§Ή Cleanup task started")
        logger.info("π Startup completed successfully!")
    except Exception as e:
        logger.error(f"β Startup failed: {str(e)}")
        # The lifespan never reaches shutdown_event when startup raises, so
        # release the MongoDB client here instead of leaking it.
        if mongodb_client is not None:
            mongodb_client.close()
            mongodb_client = None
            db = None
        raise
async def shutdown_event():
    """Release background resources when the application stops.

    Cancels the periodic cleanup task (waiting for it to finish) and closes
    the MongoDB client, when either exists.
    """
    global mongodb_client, cleanup_task
    logger.info("π Shutting down...")

    background_job = cleanup_task
    if background_job:
        background_job.cancel()
        try:
            await background_job
        except asyncio.CancelledError:
            # Cancellation is the expected outcome of the cancel() above.
            pass

    client = mongodb_client
    if client:
        client.close()

    logger.info("β Shutdown completed")
async def create_indexes():
    """Create the MongoDB indexes used by the service.

    Each collection gets a ``session_id`` index and a TTL index on
    ``created_at`` that expires documents after ``SESSION_EXPIRE_HOURS``;
    ``sessions`` additionally indexes ``status`` and ``chunks`` indexes
    ``chunk_id``. Errors are logged and not re-raised.
    """
    ttl_options = {"expireAfterSeconds": SESSION_EXPIRE_HOURS * 3600}
    # (collection name, indexed field, create_index keyword options),
    # listed in creation order.
    index_plan = (
        # Sessions collection indexes
        ("sessions", "session_id", {"unique": True}),
        ("sessions", "created_at", ttl_options),
        ("sessions", "status", {}),
        # Chunks collection indexes
        ("chunks", "session_id", {}),
        ("chunks", "chunk_id", {}),
        ("chunks", "created_at", ttl_options),
        # NER results collection indexes
        ("ner_results", "session_id", {}),
        ("ner_results", "created_at", ttl_options),
        # Summaries collection indexes
        ("summaries", "session_id", {}),
        ("summaries", "created_at", ttl_options),
    )
    try:
        for collection_name, field_name, options in index_plan:
            collection = getattr(db, collection_name)
            await collection.create_index([(field_name, ASCENDING)], **options)
        logger.info("π Database indexes created successfully")
    except Exception as e:
        logger.error(f"β Failed to create indexes: {str(e)}")
async def periodic_cleanup():
    """Run ``cleanup_expired_sessions`` once an hour until cancelled.

    Cancellation ends the loop quietly; any other error is logged and the
    loop carries on with the next cycle.
    """
    interval_seconds = 3600  # Run every hour
    keep_running = True
    while keep_running:
        try:
            await asyncio.sleep(interval_seconds)
            await cleanup_expired_sessions()
        except asyncio.CancelledError:
            keep_running = False
        except Exception as e:
            logger.error(f"β Cleanup task error: {str(e)}")
async def cleanup_expired_sessions():
    """Remove documents older than ``SESSION_EXPIRE_HOURS`` from MongoDB.

    When at least one session is past the cutoff, every document whose
    ``created_at`` precedes the cutoff is deleted from the sessions, chunks,
    NER results and summaries collections. Errors are logged, not raised.
    """
    try:
        expiry_threshold = datetime.utcnow() - timedelta(hours=SESSION_EXPIRE_HOURS)
        stale_filter = {"created_at": {"$lt": expiry_threshold}}

        # Count expired sessions; nothing to do when there are none.
        expired_count = await db.sessions.count_documents(stale_filter)
        if expired_count <= 0:
            return

        # Delete expired sessions and related data.
        for collection in (db.sessions, db.chunks, db.ner_results, db.summaries):
            await collection.delete_many(stale_filter)
        logger.info(f"π§Ή Cleaned up {expired_count} expired sessions")
    except Exception as e:
        logger.error(f"β Cleanup failed: {str(e)}")
def extract_text_from_file(file_content: bytes, filename: str) -> str:
    """Return the text of an uploaded file, chosen by its extension.

    Args:
        file_content: Raw bytes of the file.
        filename: Original file name; only its lower-cased extension is used.

    Returns:
        The extracted text. ``.txt`` content is decoded as UTF-8 with
        undecodable bytes dropped.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.txt``, ``.docx`` or
            ``.doc``.
        Exception: Any extractor error is logged and re-raised.
    """
    _, extension = os.path.splitext(filename.lower())
    try:
        if extension == '.pdf':
            return extract_text_from_pdf(file_content)
        if extension == '.txt':
            return file_content.decode('utf-8', errors='ignore')
        if extension in ('.docx', '.doc'):
            return extract_text_from_docx(file_content)
        raise ValueError(f"Unsupported file type: {extension}")
    except Exception as e:
        logger.error(f"β Text extraction failed for {filename}: {str(e)}")
        raise
def extract_text_from_pdf(file_content: bytes) -> str:
    """Extract text from a PDF file.

    Args:
        file_content: Raw bytes of the PDF.

    Returns:
        The text of all pages, each followed by a newline, or an empty
        string when no page yields any text (e.g. a scanned document).

    Raises:
        Exception: Any parsing error is logged and re-raised.
    """
    try:
        pdf_reader = PyPDF2.PdfReader(io.BytesIO(file_content))
        # `or ""` guards against a page whose extraction yields None.
        text = "".join(
            f"{page.extract_text() or ''}\n" for page in pdf_reader.pages
        )
        if not text.strip():
            # OCR is not implemented. Return an empty string so callers can
            # detect "no text" themselves; returning a placeholder sentence
            # here would be stored and processed as if it were the document.
            logger.warning("π· No text found in PDF; OCR extraction is not implemented")
            return ""
        return text
    except Exception as e:
        logger.error(f"β PDF extraction failed: {str(e)}")
        raise
def extract_text_from_docx(file_content: bytes) -> str:
    """Return the paragraph text of a DOCX document.

    Args:
        file_content: Raw bytes of the document.

    Returns:
        Every paragraph's text, each followed by a newline.

    Raises:
        Exception: Any parsing error is logged and re-raised.
    """
    try:
        document = docx.Document(io.BytesIO(file_content))
        return "".join(f"{para.text}\n" for para in document.paragraphs)
    except Exception as e:
        logger.error(f"β DOCX extraction failed: {str(e)}")
        raise
async def _run_blocking(func, *args, **kwargs):
    """Await a synchronous callable on the event loop's default executor."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


def _build_chunk_documents(session_id: str, chunks) -> list:
    """Embed each non-empty chunk and return the documents to store.

    Chunks without text, or whose embedding fails, are logged and skipped.
    """
    documents = []
    for i, chunk in enumerate(chunks):
        # Validate chunk has text
        chunk_text = chunk.get('text', '').strip()
        if not chunk_text:
            logger.warning(f"β οΈ Skipping chunk {i} - no text content")
            continue
        # Create embedding
        try:
            embedding = create_embedding(chunk_text)
        except Exception as e:
            logger.error(f"β Failed to create embedding for chunk {i}: {e}")
            continue
        # The chunk text is stored under 'content' (not 'text').
        documents.append({
            "session_id": session_id,
            "chunk_id": chunk['id'],
            "content": chunk_text,
            "title": chunk['title'],
            "section_type": chunk['section_type'],
            "importance_score": chunk['importance_score'],
            "entities": chunk['entities'],
            "embedding": embedding.tolist(),
            "created_at": datetime.utcnow()
        })
    return documents


async def process_document_pipeline(
    session_id: str,
    text: str,
    filename: str,
    background_tasks: BackgroundTasks
):
    """Process a document through NER, summarization, chunking and embedding.

    Results are written to the ``ner_results``, ``summaries`` and ``chunks``
    collections, and the session's ``status`` moves from ``processing`` to
    ``completed`` or ``failed``. Errors are recorded on the session rather
    than raised.

    Args:
        session_id: Identifier of the session being processed.
        text: Extracted document text.
        filename: Original file name, stored alongside the results.
        background_tasks: Accepted for interface compatibility; not used
            inside this function.
    """
    try:
        logger.info(f"π Starting processing pipeline for session {session_id}")

        # Update session status
        await db.sessions.update_one(
            {"session_id": session_id},
            {"$set": {"status": "processing", "updated_at": datetime.utcnow()}}
        )

        # The model calls below are synchronous. Running them directly in
        # this coroutine would stall the event loop (and every other request,
        # including status polling) for the whole run, so they are pushed to
        # a worker thread.
        # NOTE(review): this lets two uploads run inference concurrently;
        # confirm the simple.* helpers tolerate that.

        # Step 1: NER Processing
        ner_model_id = "kn29/my-ner-model"
        logger.info(f"π Running NER for session {session_id} using model: {ner_model_id}")
        ner_results = await _run_blocking(run_ner, text, model_id=ner_model_id)
        if ner_results.get("error"):
            # An NER failure is logged but does not stop the pipeline.
            logger.error(f"β NER failed for session {session_id}: {ner_results['error']}")
        else:
            logger.info(
                f"β NER completed for session {session_id} β’ total_entities={ner_results.get('total_entities', 0)} β’ labels={len(ner_results.get('unique_labels', []))}"
            )

        # Store NER results
        await db.ner_results.insert_one({
            "session_id": session_id,
            "filename": filename,
            "results": ner_results,
            "created_at": datetime.utcnow()
        })

        # Step 2: Summarization
        logger.info(f"π Running summarization for session {session_id} (Groq={'on' if GROQ_API_KEY else 'off'})")
        summary_results = await _run_blocking(
            summarize_legal_document,
            text,
            max_sentences=5,
            groq_api_key=GROQ_API_KEY
        )

        # Store summary results
        await db.summaries.insert_one({
            "session_id": session_id,
            "filename": filename,
            "results": summary_results,
            "created_at": datetime.utcnow()
        })

        # Step 3: Chunking and Embedding
        logger.info(f"π§© Creating chunks and embeddings for session {session_id} using {HF_MODEL_ID}")
        chunks = await _run_blocking(chunk_text_hierarchical, text, filename)
        logger.info(f"π Created {len(chunks)} chunks from document")

        # Create embeddings for all chunks in a single worker-thread call.
        chunks_to_store = await _run_blocking(
            _build_chunk_documents, session_id, chunks
        )

        # Batch insert chunks
        if chunks_to_store:
            await db.chunks.insert_many(chunks_to_store)
            logger.info(f"β Stored {len(chunks_to_store)} chunks with embeddings")
        else:
            raise Exception("No valid chunks created from document")

        # Update session as completed
        completed_at = datetime.utcnow()
        await db.sessions.update_one(
            {"session_id": session_id},
            {
                "$set": {
                    "status": "completed",
                    "updated_at": completed_at,
                    "chunk_count": len(chunks_to_store),
                    "processing_completed_at": completed_at
                }
            }
        )
        logger.info(f"β Processing completed for session {session_id}")
    except Exception as e:
        logger.error(f"β Processing failed for session {session_id}: {str(e)}")
        # Update session with error. This write can fail too (for example
        # when the database itself caused the failure); log it instead of
        # letting an unhandled exception escape the background task.
        try:
            await db.sessions.update_one(
                {"session_id": session_id},
                {
                    "$set": {
                        "status": "failed",
                        "error": str(e),
                        "updated_at": datetime.utcnow()
                    }
                }
            )
        except Exception as update_error:
            logger.error(
                f"Could not record failure for session {session_id}: {update_error}"
            )
@app.post("/upload")
async def upload_document(
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...)
):
    """Upload a legal document and schedule it for processing.

    Validates the file name, extension and size, extracts the text, stores a
    session record and queues ``process_document_pipeline`` as a background
    task.

    Args:
        background_tasks: FastAPI background task queue for this request.
        file: The uploaded document.

    Returns:
        A 200 JSON response with the new ``session_id`` and basic statistics
        about the file.

    Raises:
        HTTPException: 400 for a missing file name, unsupported extension,
            oversized file or empty extracted text; 500 for any other
            failure.
    """
    try:
        # Validate file
        if not file.filename:
            raise HTTPException(status_code=400, detail="No file provided")

        file_ext = os.path.splitext(file.filename.lower())[1]
        if file_ext not in SUPPORTED_EXTENSIONS:
            # sorted() keeps the message stable; set order is arbitrary.
            raise HTTPException(
                status_code=400,
                detail=f"Unsupported file type. Supported: {', '.join(sorted(SUPPORTED_EXTENSIONS))}"
            )

        # Check file size. Reading at most one byte past the limit is enough
        # to detect an oversized upload without loading all of it in memory.
        file_content = await file.read(MAX_FILE_SIZE + 1)
        if len(file_content) > MAX_FILE_SIZE:
            raise HTTPException(
                status_code=400,
                detail=f"File too large. Maximum size: {MAX_FILE_SIZE // (1024*1024)}MB"
            )

        # Generate session ID
        session_id = str(uuid.uuid4())

        # Extract text in a worker thread: parsing is synchronous and would
        # otherwise block the event loop for the duration.
        logger.info(f"π Extracting text from {file.filename}")
        text = await asyncio.get_running_loop().run_in_executor(
            None, extract_text_from_file, file_content, file.filename
        )
        if not text.strip():
            raise HTTPException(status_code=400, detail="No text could be extracted from the file")

        file_size = len(file_content)
        text_length = len(text)
        word_count = len(text.split())

        # Create session record
        now = datetime.utcnow()
        await db.sessions.insert_one({
            "session_id": session_id,
            "filename": file.filename,
            "file_size": file_size,
            "text_length": text_length,
            "word_count": word_count,
            "status": "uploaded",
            "created_at": now,
            "updated_at": now
        })

        # Start background processing
        background_tasks.add_task(
            process_document_pipeline,
            session_id,
            text,
            file.filename,
            background_tasks
        )
        logger.info(f"β Document uploaded successfully. Session ID: {session_id}")

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session_id": session_id,
                "filename": file.filename,
                "file_size": file_size,
                "text_length": text_length,
                "word_count": word_count,
                "status": "processing",
                "message": "Document uploaded successfully. Processing started."
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Upload failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
@app.get("/status/{session_id}")
async def get_session_status(session_id: str):
    """Get the processing status of a session.

    Args:
        session_id: Identifier returned by the upload endpoint.

    Returns:
        A 200 JSON response containing the session record with its id and
        timestamps converted to strings. Completed sessions also carry
        ``ner_entities``, ``summary_available`` and ``chunk_count``.

    Raises:
        HTTPException: 404 if the session does not exist; 500 on any other
            failure.
    """
    try:
        session = await db.sessions.find_one({"session_id": session_id})
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")

        # ObjectId and datetime values are not JSON serializable.
        session["_id"] = str(session["_id"])
        for field in ("created_at", "updated_at", "processing_completed_at"):
            if session.get(field):
                session[field] = session[field].isoformat()

        # Add processing progress info
        if session["status"] == "completed":
            ner_result = await db.ner_results.find_one({"session_id": session_id})
            summary_result = await db.summaries.find_one({"session_id": session_id})
            chunk_count = await db.chunks.count_documents({"session_id": session_id})

            # The pipeline stores the NER result even when NER failed, in
            # which case "total_entities" may be absent; fall back to 0
            # instead of raising KeyError.
            ner_payload = (ner_result or {}).get("results") or {}
            session["ner_entities"] = ner_payload.get("total_entities", 0)
            session["summary_available"] = bool(summary_result)
            session["chunk_count"] = chunk_count

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session": session
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Status check failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Status check failed: {str(e)}")
@app.get("/results/{session_id}")
async def get_processing_results(session_id: str):
    """Get all processing results for a session.

    Args:
        session_id: Identifier returned by the upload endpoint.

    Returns:
        A 202 JSON response while the session is not ``completed``;
        otherwise a 200 response with the NER results, the summary results,
        the total chunk count and metadata for the first 10 chunks.

    Raises:
        HTTPException: 404 if the session does not exist; 500 on any other
            failure.
    """
    try:
        # Check if session exists and is completed
        session = await db.sessions.find_one({"session_id": session_id})
        if not session:
            raise HTTPException(status_code=404, detail="Session not found")
        if session["status"] != "completed":
            return JSONResponse(
                status_code=202,
                content={
                    "success": False,
                    "message": f"Processing not completed. Current status: {session['status']}"
                }
            )

        ner_result = await db.ner_results.find_one({"session_id": session_id})
        summary_result = await db.summaries.find_one({"session_id": session_id})

        # Chunk metadata only. The pipeline stores chunk text under
        # "content", so that field must be excluded along with the
        # embedding; "text" stays excluded for any older documents.
        chunk_filter = {"session_id": session_id}
        total_chunks = await db.chunks.count_documents(chunk_filter)
        # Only the first 10 chunks are returned, so only those are fetched.
        chunks_metadata = await db.chunks.find(
            chunk_filter,
            {"content": 0, "text": 0, "embedding": 0}
        ).to_list(length=10)

        # ObjectId and datetime values are not JSON serializable.
        for chunk in chunks_metadata:
            chunk["_id"] = str(chunk["_id"])
            if chunk.get("created_at"):
                chunk["created_at"] = chunk["created_at"].isoformat()

        processing_completed_at = session.get("processing_completed_at")
        if processing_completed_at:
            processing_completed_at = processing_completed_at.isoformat()

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "session_id": session_id,
                "filename": session["filename"],
                "ner_results": ner_result["results"] if ner_result else None,
                "summary_results": summary_result["results"] if summary_result else None,
                "chunks_metadata": {
                    "total_chunks": total_chunks,
                    "chunks": chunks_metadata
                },
                "processing_completed_at": processing_completed_at
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Results retrieval failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Results retrieval failed: {str(e)}")
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A 200 JSON response when a MongoDB ping succeeds, otherwise a 503
        response carrying the error message. Note that ``ml_models`` is
        reported as a fixed value; the models are not probed here.
    """
    try:
        # Test MongoDB connection
        await mongodb_client.admin.command('ping')
    except Exception as e:
        logger.error(f"β Health check failed: {str(e)}")
        return JSONResponse(
            status_code=503,
            content={
                "status": "unhealthy",
                "error": str(e),
                "timestamp": datetime.utcnow().isoformat()
            }
        )
    return JSONResponse(
        status_code=200,
        content={
            "status": "healthy",
            "timestamp": datetime.utcnow().isoformat(),
            "services": {
                "mongodb": "connected",
                "ml_models": "loaded"
            }
        }
    )
# NOTE(review): no route decorator is visible for this handler and its path is
# not listed by the root endpoint; register it under the intended path.
async def ner_health_check():
    """Verify NER model can load and process a tiny input.

    Returns:
        A 200 JSON response with the entity count and labels found in a
        fixed test sentence, or a 503 response with the error message when
        the NER call raises or reports an error.
    """
    try:
        ner_model_id = "kn29/my-ner-model"
        # run_ner is synchronous; run it in a worker thread so the probe
        # does not block the event loop while the model loads.
        result = await asyncio.get_running_loop().run_in_executor(
            None,
            lambda: run_ner("Test entity: Supreme Court.", model_id=ner_model_id),
        )
        # run_ner signals failure through an "error" key (the processing
        # pipeline checks the same key); that must not be reported as ready.
        if result.get("error"):
            return JSONResponse(
                status_code=503,
                content={
                    "status": "error",
                    "error": str(result["error"])
                }
            )
        return JSONResponse(
            status_code=200,
            content={
                "status": "ready",
                "model_id": ner_model_id,
                "total_entities": result.get("total_entities", 0),
                "labels": result.get("unique_labels", []),
            }
        )
    except Exception as e:
        return JSONResponse(
            status_code=503,
            content={
                "status": "error",
                "error": str(e)
            }
        )
@app.delete("/session/{session_id}")
async def delete_session(session_id: str):
    """Manually delete a session and all related data.

    Args:
        session_id: Identifier of the session to remove.

    Returns:
        A 200 JSON response confirming the deletion.

    Raises:
        HTTPException: 404 if no session record was deleted; 500 on any
            other failure.
    """
    try:
        session_filter = {"session_id": session_id}

        # Delete from all collections. Related documents are removed even
        # when the session record itself is already gone.
        session_result = await db.sessions.delete_one(session_filter)
        await db.chunks.delete_many(session_filter)
        await db.ner_results.delete_many(session_filter)
        await db.summaries.delete_many(session_filter)

        if session_result.deleted_count == 0:
            raise HTTPException(status_code=404, detail="Session not found")

        return JSONResponse(
            status_code=200,
            content={
                "success": True,
                "message": f"Session {session_id} deleted successfully"
            }
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"β Session deletion failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Session deletion failed: {str(e)}")
@app.get("/")
async def root():
    """Root endpoint with API information.

    Returns:
        A dict naming the service, its version, the available endpoints and
        the supported file extensions.
    """
    return {
        "service": "Legal Document Processor",
        "version": "1.0.0",
        "status": "running",
        "endpoints": {
            "upload": "POST /upload - Upload a legal document for processing",
            "status": "GET /status/{session_id} - Check processing status",
            "results": "GET /results/{session_id} - Get processing results",
            "health": "GET /health - Health check",
            "delete": "DELETE /session/{session_id} - Delete a session"
        },
        # sorted() gives a stable order; iterating the set directly does not.
        "supported_formats": sorted(SUPPORTED_EXTENSIONS)
    }
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces.
    # The PORT environment variable overrides the default port 7860.
    server_options = {
        "host": "0.0.0.0",
        "port": int(os.getenv("PORT", 7860)),
        "reload": False,
        "access_log": True,
    }
    uvicorn.run("app:app", **server_options)