# Updated main.py with static file serving for production
from fastapi import FastAPI, UploadFile, File, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, FileResponse, Response
import os
import json
import uuid
import asyncio
import logging
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional
import shutil
from multi_agent_ocr import analyze_document_with_flexible_supervision
from pdf2image import convert_from_path
from PIL import Image
from starlette.exceptions import HTTPException as StarletteHTTPException
from logging.handlers import QueueListener
import queue

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="OCR/LAD/RAD Interface",
    version="1.0.0",
    description="Advanced document analysis with AI agents"
)

# CORS middleware for development
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000"],  # React dev server
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables
upload_directory = Path("uploads")
upload_directory.mkdir(exist_ok=True)

# Document storage (in production, use a database)
documents_store: Dict[str, Dict[str, Any]] = {}

# --- WebSocket Logging Setup ---
# 1. Create a queue to hold log records
log_queue = queue.Queue()

# 2. Create a custom handler that puts logs into the queue
class WebSocketLogHandler(logging.Handler):
    def __init__(self, queue):
        super().__init__()
        self.queue = queue

    def emit(self, record):
        # Only process logs that we want to show in the frontend
        if "HTTP Request" not in record.getMessage() and "uvicorn" not in record.name:
            log_entry = self.format(record)
            self.queue.put(log_entry)

# 3. Get the root logger and add the queue handler
root_logger = logging.getLogger()
# Ensure we capture INFO level logs and above
root_logger.setLevel(logging.INFO)
# Use a specific format for frontend logs
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s', datefmt='%H:%M:%S')
queue_handler = WebSocketLogHandler(log_queue)
queue_handler.setFormatter(formatter)
root_logger.addHandler(queue_handler)
# --- End WebSocket Logging Setup ---
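
# --- Example: consuming the log stream from a client ---
# A minimal sketch of a client tailing the log stream, assuming the
# `websockets` package is installed and the endpoint below is mounted at
# the assumed "/ws" path. Defined for illustration only; nothing in this
# module calls it.
async def _example_tail_logs(url: str = "ws://localhost:8000/ws"):
    import websockets  # local import keeps the dependency optional

    async with websockets.connect(url) as ws:
        while True:
            # Each frame is a JSON object such as {"type": "log", "content": "..."}
            print(json.loads(await ws.recv()))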

class ConnectionManager:
    def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.loop: Optional[asyncio.AbstractEventLoop] = None
        # Start a listener thread that drains the log queue into this manager
        self.log_listener = QueueListener(log_queue, self)
        self.log_listener.start()

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        # Remember the server's event loop so the listener thread can
        # schedule broadcasts onto it safely
        self.loop = asyncio.get_running_loop()
        self.active_connections.append(websocket)
        await self.send_message_to(websocket, {"type": "system", "content": "WebSocket connection established. Streaming real-time logs..."})

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)

    async def send_message(self, message: dict):
        """Send message to all connected clients"""
        if self.active_connections:
            disconnected = []
            for connection in self.active_connections:
                try:
                    await connection.send_text(json.dumps(message))
                except Exception:
                    disconnected.append(connection)
            # Remove disconnected clients
            for conn in disconnected:
                self.disconnect(conn)

    async def send_message_to(self, websocket: WebSocket, message: dict):
        await websocket.send_json(message)

    # This method makes the manager a valid "handler" for the QueueListener
    def handle(self, record):
        # Called from the QueueListener's worker thread: schedule the
        # broadcast on the server's event loop instead of spinning up a
        # fresh loop per record, which would break the websocket sends.
        if self.loop is not None:
            asyncio.run_coroutine_threadsafe(
                self.send_message({"type": "log", "content": record}), self.loop
            )

    def stop(self):
        self.log_listener.stop()

manager = ConnectionManager()

@app.websocket("/ws")  # assumed path; match it to the frontend's socket URL
async def websocket_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            # Block on receive so WebSocketDisconnect fires when the client leaves
            await websocket.receive_text()
    except WebSocketDisconnect:
        manager.disconnect(websocket)

@app.post("/api/upload")  # assumed path
async def upload_document(file: UploadFile = File(...)):
    """Upload a document for OCR processing"""
    # Validate file type
    allowed_extensions = {'.pdf', '.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.tiff'}
    file_extension = Path(file.filename).suffix.lower()
    if file_extension not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"File type {file_extension} not supported. Allowed: {', '.join(sorted(allowed_extensions))}"
        )
    # Generate unique ID for this document
    document_id = str(uuid.uuid4())
    # Save file
    file_path = upload_directory / f"{document_id}_{file.filename}"
    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        # Get file info
        file_size = file_path.stat().st_size
        # Store document metadata
        documents_store[document_id] = {
            "document_id": document_id,
            "filename": file.filename,
            "file_path": str(file_path),
            "file_size": file_size,
            "upload_time": datetime.now().isoformat(),
            "status": "uploaded",
            "extraction_result": None
        }
        await manager.send_message({
            "type": "success",
            "message": f"Document uploaded: {file.filename} ({file_size / 1024:.1f} KB)",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        return {
            "document_id": document_id,
            "filename": file.filename,
            "file_size": file_size,
            "status": "uploaded"
        }
    except Exception as e:
        logger.error(f"Upload error: {e}")
        await manager.send_message({
            "type": "error",
            "message": f"Upload failed: {str(e)}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        raise HTTPException(status_code=500, detail=f"Upload failed: {str(e)}")
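
# A minimal client sketch for the upload route above, assuming httpx is
# installed and the assumed /api/upload path. Illustration only; nothing
# in this module calls it.
def _example_upload(path: str, base_url: str = "http://localhost:8000") -> str:
    import httpx  # local import keeps the dependency optional

    with open(path, "rb") as fh:
        response = httpx.post(f"{base_url}/api/upload", files={"file": fh})
    response.raise_for_status()
    return response.json()["document_id"]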

@app.delete("/api/documents/{document_id}")  # assumed path
async def delete_document(document_id: str):
    """Delete a document"""
    if document_id not in documents_store:
        raise HTTPException(status_code=404, detail="Document not found")
    try:
        # Delete file
        file_path = Path(documents_store[document_id]["file_path"])
        if file_path.exists():
            file_path.unlink()
        # Remove from store
        filename = documents_store[document_id]["filename"]
        del documents_store[document_id]
        await manager.send_message({
            "type": "warning",
            "message": f"Document deleted: {filename}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        return {"message": "Document deleted successfully"}
    except Exception as e:
        logger.error(f"Delete error: {e}")
        await manager.send_message({
            "type": "error",
            "message": f"Delete failed: {str(e)}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        raise HTTPException(status_code=500, detail=f"Delete failed: {str(e)}")

@app.post("/api/documents/{document_id}/extract")  # assumed path
async def extract_document(document_id: str):
    """Extract information from document using OCR/LAD/RAD agents"""
    if document_id not in documents_store:
        raise HTTPException(status_code=404, detail="Document not found")
    document = documents_store[document_id]
    file_path = document["file_path"]
    try:
        await manager.send_message({
            "type": "info",
            "step": "extraction_start",
            "message": f"Starting OCR/LAD/RAD analysis for: {document['filename']}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        # Update status
        documents_store[document_id]["status"] = "processing"
        # Get API keys from environment
        anthropic_key = os.getenv("ANTHROPIC_API_KEY")
        openai_key = os.getenv("OPENAI_API_KEY")
        if not anthropic_key or not openai_key:
            raise HTTPException(
                status_code=500,
                detail="API keys not configured. Set ANTHROPIC_API_KEY and OPENAI_API_KEY environment variables."
            )
        await manager.send_message({
            "type": "info",
            "step": "agents_init",
            "message": "Initializing intelligent OCR agents...",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        # Run the blocking OCR analysis in a worker thread so the event
        # loop stays free to stream websocket log messages
        result = await asyncio.to_thread(
            analyze_document_with_flexible_supervision,
            file_path=file_path,
            anthropic_api_key=anthropic_key,
            openai_api_key=openai_key
        )
        # Process results for frontend
        processed_result = await process_extraction_result(result, document_id)
        # Store result
        documents_store[document_id]["extraction_result"] = processed_result
        documents_store[document_id]["status"] = "completed"
        await manager.send_message({
            "type": "success",
            "step": "extraction_complete",
            "message": f"Analysis completed successfully! Quality score: {processed_result.get('business_logic_score', 0):.2%}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        return processed_result
    except HTTPException:
        # Preserve the intended status and detail instead of re-wrapping
        documents_store[document_id]["status"] = "error"
        raise
    except Exception as e:
        logger.error(f"Extraction error: {e}")
        documents_store[document_id]["status"] = "error"
        await manager.send_message({
            "type": "error",
            "step": "extraction_error",
            "message": f"Extraction failed: {str(e)}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        raise HTTPException(status_code=500, detail=f"Extraction failed: {str(e)}")
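
# Companion sketch to _example_upload: trigger extraction for an uploaded
# document (assumed path; illustration only, never called here). The long
# timeout is an assumption to accommodate slow multi-agent analysis.
def _example_extract(document_id: str, base_url: str = "http://localhost:8000") -> dict:
    import httpx  # local import keeps the dependency optional

    response = httpx.post(f"{base_url}/api/documents/{document_id}/extract", timeout=300.0)
    response.raise_for_status()
    return response.json()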

async def process_extraction_result(result: Dict[str, Any], document_id: str) -> Dict[str, Any]:
    """Process and enhance extraction result for frontend display"""
    if result.get("error"):
        await manager.send_message({
            "type": "error",
            "step": "result_error",
            "message": f"Extraction error: {result.get('message', 'Unknown error')}",
            "timestamp": datetime.now().strftime("%H:%M:%S")
        })
        return result
    # Normalize field names for frontend compatibility, starting from a copy
    normalized_result = dict(result)
    # Ensure score fields have the names the frontend expects
    if 'image_quality' in result:
        normalized_result['image_quality_score'] = result['image_quality']
    if 'business_logic' in result:
        normalized_result['business_logic_score'] = result['business_logic']
    if 'information_relevance' in result:
        normalized_result['information_relevance_score'] = result['information_relevance']
    # Flatten and normalize extracted_data for frontend compatibility
    extracted_data = result.get('extracted_data', {})
    if extracted_data:
        # Convert key_fields to frontend format
        if 'key_fields' in extracted_data:
            key_fields = extracted_data['key_fields']
            if isinstance(key_fields, dict):
                normalized_result['key_fields'] = [
                    {"field": k, "value": v, "confidence": 0.9}
                    for k, v in key_fields.items()
                ]
            else:
                normalized_result['key_fields'] = key_fields
        # Convert dates to frontend format
        if 'dates' in extracted_data:
            dates = extracted_data['dates']
            if isinstance(dates, list) and dates:
                if isinstance(dates[0], str):
                    # Convert string array to object array
                    normalized_result['dates'] = [
                        {"date_type": f"Date {i+1}", "date_value": date, "confidence": 0.9}
                        for i, date in enumerate(dates)
                    ]
                else:
                    normalized_result['dates'] = dates
            else:
                normalized_result['dates'] = []
        # Convert amounts to frontend format
        if 'amounts' in extracted_data:
            amounts = extracted_data['amounts']
            if isinstance(amounts, list) and amounts:
                if isinstance(amounts[0], str):
                    # Convert string array to object array
                    normalized_result['amounts'] = [
                        {"amount_type": f"Amount {i+1}", "amount_value": amount, "currency": "€", "confidence": 0.9}
                        for i, amount in enumerate(amounts)
                    ]
                else:
                    normalized_result['amounts'] = amounts
            else:
                normalized_result['amounts'] = []
        # Convert entities to frontend format
        if 'entities' in extracted_data:
            entities = extracted_data['entities']
            if isinstance(entities, dict):
                # Convert nested dict to flat array
                entities_array = []
                for entity_type, entity_list in entities.items():
                    if isinstance(entity_list, list):
                        for entity in entity_list:
                            entities_array.append({
                                "entity_type": entity_type,
                                "entity_value": entity,
                                "confidence": 0.9
                            })
                normalized_result['entities'] = entities_array
            else:
                normalized_result['entities'] = entities
    # Expose the OCR text under a single field name, whatever the agents returned
    if 'full_text' in result:
        normalized_result['ocr_text'] = result['full_text']
    elif 'raw_text' in result:
        normalized_result['ocr_text'] = result['raw_text']
    elif 'text' in result:
        normalized_result['ocr_text'] = result['text']
    else:
        normalized_result['ocr_text'] = "No text extracted"
    # Log key quality metrics
    image_quality = normalized_result.get('image_quality_score', 0)
    business_logic = normalized_result.get('business_logic_score', 0)
    info_relevance = normalized_result.get('information_relevance_score', 0)
    await manager.send_message({
        "type": "info",
        "step": "quality_metrics",
        "message": f"Image Quality: {image_quality:.1%} | Business Logic: {business_logic:.1%} | Information Relevance: {info_relevance:.1%}",
        "timestamp": datetime.now().strftime("%H:%M:%S")
    })
    # Log extraction summary
    key_fields_count = len(normalized_result.get('key_fields', []))
    dates_count = len(normalized_result.get('dates', []))
    amounts_count = len(normalized_result.get('amounts', []))
    entities_count = len(normalized_result.get('entities', []))
    await manager.send_message({
        "type": "success",
        "step": "extraction_summary",
        "message": f"Extracted: {key_fields_count} key fields, {dates_count} dates, {amounts_count} amounts, {entities_count} entities",
        "timestamp": datetime.now().strftime("%H:%M:%S")
    })
    # Enhanced result with metadata
    enhanced_result = {
        **normalized_result,
        "document_id": document_id,
        "processing_time": datetime.now().isoformat(),
        "summary": {
            "total_fields": key_fields_count + dates_count + amounts_count + entities_count,
            "key_fields_count": key_fields_count,
            "dates_count": dates_count,
            "amounts_count": amounts_count,
            "entities_count": entities_count,
            "pages_processed": result.get('total_pages', 1)
        }
    }
    return enhanced_result
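
# Shape reference for the normalization above, with hypothetical values:
#
#   {"key_fields": {"invoice_number": "F-2024-001"},
#    "dates": ["2024-01-15"],
#    "amounts": ["1250.00"]}
#
# becomes
#
#   {"key_fields": [{"field": "invoice_number", "value": "F-2024-001", "confidence": 0.9}],
#    "dates": [{"date_type": "Date 1", "date_value": "2024-01-15", "confidence": 0.9}],
#    "amounts": [{"amount_type": "Amount 1", "amount_value": "1250.00", "currency": "€", "confidence": 0.9}]}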

@app.get("/api/documents/{document_id}")  # assumed path
async def get_document(document_id: str):
    """Get document metadata and results"""
    if document_id not in documents_store:
        raise HTTPException(status_code=404, detail="Document not found")
    return documents_store[document_id]

@app.get("/api/documents/{document_id}/thumbnail")  # assumed path
async def get_document_thumbnail(document_id: str):
    """Generate and return a thumbnail of the first page of the document"""
    if document_id not in documents_store:
        raise HTTPException(status_code=404, detail="Document not found")
    document = documents_store[document_id]
    file_path = Path(document["file_path"])
    try:
        # Check if file exists
        if not file_path.exists():
            raise HTTPException(status_code=404, detail="Document file not found")
        thumbnail_path = upload_directory / f"{document_id}_thumb.jpg"
        # Return the cached thumbnail if it already exists
        if thumbnail_path.exists():
            return FileResponse(thumbnail_path, media_type="image/jpeg")
        # Generate thumbnail based on file type
        file_extension = file_path.suffix.lower()
        if file_extension == '.pdf':
            # Convert first page of PDF to image
            try:
                pages = convert_from_path(file_path, first_page=1, last_page=1, dpi=150)
                if pages:
                    first_page = pages[0]
                    # Resize to thumbnail size (max 300px width, maintain aspect ratio)
                    first_page.thumbnail((300, 400), Image.Resampling.LANCZOS)
                    first_page.save(thumbnail_path, 'JPEG', quality=85)
                else:
                    raise HTTPException(status_code=500, detail="Could not convert PDF to image")
            except HTTPException:
                raise
            except Exception as e:
                logger.error(f"PDF thumbnail generation error: {e}")
                raise HTTPException(status_code=500, detail=f"PDF thumbnail generation failed: {str(e)}")
        elif file_extension in {'.jpg', '.jpeg', '.png', '.gif', '.webp', '.bmp', '.tiff'}:
            # Generate thumbnail from image
            try:
                with Image.open(file_path) as img:
                    # Convert to RGB if necessary (for JPEG output)
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    # Resize to thumbnail size
                    img.thumbnail((300, 400), Image.Resampling.LANCZOS)
                    img.save(thumbnail_path, 'JPEG', quality=85)
            except Exception as e:
                logger.error(f"Image thumbnail generation error: {e}")
                raise HTTPException(status_code=500, detail=f"Image thumbnail generation failed: {str(e)}")
        else:
            raise HTTPException(status_code=400, detail="Unsupported file type for thumbnail")
        return FileResponse(thumbnail_path, media_type="image/jpeg")
    except HTTPException:
        # Preserve intended status codes (404/400) instead of wrapping them in a 500
        raise
    except Exception as e:
        logger.error(f"Thumbnail generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Thumbnail generation failed: {str(e)}")

@app.get("/api/documents")  # assumed path
async def list_documents():
    """List all uploaded documents"""
    return list(documents_store.values())

@app.get("/api/health")  # assumed path
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "timestamp": datetime.now().isoformat(),
        "documents_count": len(documents_store),
        "active_connections": len(manager.active_connections),
        "api_keys_configured": {
            "anthropic": bool(os.getenv("ANTHROPIC_API_KEY")),
            "openai": bool(os.getenv("OPENAI_API_KEY"))
        }
    }
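
# A quick smoke-test sketch using FastAPI's TestClient (requires httpx),
# assuming the routes keep the assumed paths used above. Run it manually;
# it is not invoked by this module.
def _example_smoke_test():
    from fastapi.testclient import TestClient  # local import keeps it optional

    client = TestClient(app)
    health = client.get("/api/health")
    assert health.status_code == 200
    assert health.json()["status"] == "healthy"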

@app.get("/favicon.ico", include_in_schema=False)  # assumed path
async def favicon():
    """Serve an empty response to avoid favicon 404 noise in browsers"""
    # A 204 must not carry a body, so return a bare Response
    return Response(status_code=204)

# --- NEW, ROBUST SPA STATIC FILE SERVING LOGIC ---
# This block replaces the old, problematic static file serving logic.
# It correctly handles serving the React app, its assets (CSS, JS), and client-side routing.
class SPAStaticFiles(StaticFiles):
    """
    Custom StaticFiles class to serve a Single Page Application (SPA).
    If a requested file is not found, it serves 'index.html' to allow
    client-side routing (e.g., React Router) to handle the path.
    """
    async def get_response(self, path: str, scope):
        try:
            # Try to get the file for the given path from the static directory
            return await super().get_response(path, scope)
        except (StarletteHTTPException, RuntimeError) as ex:
            # Check if it's a 404 Not Found error. A RuntimeError can also be raised
            # by Starlette for missing files under some conditions.
            is_404 = isinstance(ex, StarletteHTTPException) and ex.status_code == 404
            if is_404 or isinstance(ex, RuntimeError):
                # If the file is not found, serve the 'index.html' file.
                # This is the key to making client-side routing work.
                return await super().get_response("index.html", scope)
            # Re-raise any other exceptions
            raise ex

# Check if the 'static' directory (our React build output) exists in the current working directory.
if os.path.exists("static"):
    # Mount the entire static directory at the root ("/") of the app.
    # The custom SPAStaticFiles class will handle serving all assets (like CSS, JS, images)
    # and will serve index.html as a fallback for any path that doesn't match a file.
    app.mount("/", SPAStaticFiles(directory="static"), name="spa-static-files")
else:
    # If the static directory doesn't exist, this provides a helpful message for developers.
    @app.get("/")  # served only in API-only mode
    async def development_landing_page():
        return HTMLResponse(
            """
            <!DOCTYPE html>
            <html>
            <head>
                <title>OCR/LAD/RAD Platform - Backend Running</title>
                <style>
                    body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0; background: #f0f2f5; text-align: center; }
                    .container { padding: 40px; background: white; border-radius: 12px; box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1); }
                </style>
            </head>
            <body>
                <div class="container">
                    <h1>🤖 Backend is Running</h1>
                    <p>Static files directory not found in <code>/backend/static</code>.</p>
                    <p>The application is running in API-only mode. Please build the frontend or run it separately in development.</p>
                </div>
            </body>
            </html>
            """
        )
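
# Behavior sketch for the SPA fallback above, assuming a built frontend
# exists in ./static (illustration only, written as a manual check):
#
#   from fastapi.testclient import TestClient
#   client = TestClient(app)
#   # A deep client-side route still returns index.html, not a 404
#   assert client.get("/documents/some-client-route").status_code == 200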

@app.on_event("shutdown")
def shutdown_event():
    # Stop the QueueListener thread so the process can exit cleanly
    manager.stop()

if __name__ == "__main__":
    import uvicorn

    # Check API keys
    if not os.getenv("ANTHROPIC_API_KEY"):
        logger.warning("ANTHROPIC_API_KEY not set")
    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("OPENAI_API_KEY not set")
    # Get port from environment (Hugging Face Spaces uses 7860)
    port = int(os.getenv("PORT", 8000))
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        reload=os.getenv("DEBUG", "false").lower() == "true",
        log_level="info"
    )