Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import os | |
| import tempfile | |
| import shutil | |
| from typing import List, Optional, Dict, Any | |
| import pathlib | |
| import asyncio | |
| import logging | |
| import time | |
| import traceback | |
| import uuid | |
# Emit INFO-and-above records prefixed with timestamp, logger name and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)
| # Import our RAG components | |
| from rag import RetrievalAugmentedQAPipeline, process_file, setup_vector_db | |
| # Add local aimakerspace module to the path | |
| import sys | |
| sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), "")) | |
| # Import from local aimakerspace module | |
| from aimakerspace.utils.session_manager import SessionManager | |
| # Load environment variables | |
| from dotenv import load_dotenv | |
| load_dotenv() | |
# Application instance that the middleware below is attached to.
app = FastAPI()

# CORS is left wide open (any origin, method and header) for development.
# NOTE(review): FastAPI's CORS documentation states that wildcard
# origins/methods/headers should not be combined with allow_credentials=True;
# confirm whether credentials are actually needed and pin explicit origins
# before exposing this beyond development.
_CORS_OPTIONS = dict(
    allow_origins=["*"],  # This will allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
    expose_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# Single shared manager holding the state of every upload session.
session_manager = SessionManager()
class QueryRequest(BaseModel):
    """Request body for asking a question against an existing session."""

    # Identifier of the session whose stored pipeline should answer.
    session_id: str
    # The question text handed to the pipeline.
    query: str
class QueryResponse(BaseModel):
    """Answer payload: the generated text and the session it belongs to."""

    # Full response text.
    response: str
    # Identifier of the session the answer was produced for.
    session_id: str
# Largest upload accepted, in bytes (10 MB) - adjust as needed.
FILE_SIZE_LIMIT = 10 * 1024 ** 2  # 10MB
def _remove_temp_file(temp_path: str) -> None:
    """Delete the uploaded temp file; a missing file or OS error is not fatal."""
    try:
        os.unlink(temp_path)
        logger.info(f"Removed temporary file: {temp_path}")
    except FileNotFoundError:
        # Nothing left to clean up.
        pass
    except OSError as cleanup_error:
        logger.error(f"Error cleaning up temp file: {str(cleanup_error)}")


async def process_file_background(temp_path: str, filename: str, session_id: str):
    """Process an uploaded file in the background and set up the RAG pipeline.

    Extracts text chunks from the file, builds a vector database from them
    (bounded by an overall time budget) and stores the resulting pipeline in
    the session. On any failure the session is updated to ``"failed"``
    instead. The temporary file is removed on every path.

    Args:
        temp_path: Path of the temporary file holding the upload.
        filename: Original name of the uploaded file.
        session_id: Session to update with the pipeline or ``"failed"``.
    """
    # Overall budget for extraction plus vector DB creation (5 minutes).
    max_processing_time = 300  # seconds
    # Monotonic clock: durations must not be affected by wall-clock changes.
    start_time = time.monotonic()
    try:
        logger.info(f"Background processing started for file: {filename} (session: {session_id})")

        # --- Text extraction -------------------------------------------------
        # NOTE(review): process_file is called synchronously, so it runs on the
        # event loop thread for its whole duration - consider offloading it if
        # extraction is slow.
        logger.info(f"Starting text extraction for file: {filename}")
        try:
            texts = process_file(temp_path, filename)
            logger.info(f"Processed file into {len(texts)} text chunks (took {time.monotonic() - start_time:.2f}s)")
            # If extraction already used more than half the budget, cap the
            # number of chunks so embedding has a chance to finish in time.
            if time.monotonic() - start_time > max_processing_time / 2:
                logger.warning("Text extraction took more than half the allowed time. Limiting chunks...")
                max_chunks = 50
                if len(texts) > max_chunks:
                    logger.warning(f"Limiting text chunks from {len(texts)} to {max_chunks}")
                    texts = texts[:max_chunks]
        except Exception as e:
            logger.error(f"Error during text extraction: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
            return  # temp file is removed by the ``finally`` below

        # --- Vector database + pipeline -------------------------------------
        logger.info(f"Starting vector DB creation for {len(texts)} chunks")
        embedding_start = time.monotonic()
        try:
            # Whatever remains of the budget; never pass a negative timeout.
            remaining = max(0.0, max_processing_time - (embedding_start - start_time))
            vector_db = await asyncio.wait_for(setup_vector_db(texts), timeout=remaining)

            # Document count: prefer ``documents``; fall back to the
            # ``vectors`` mapping of the original VectorDatabase, else 0.
            if hasattr(vector_db, 'documents'):
                doc_count = len(vector_db.documents)
            elif hasattr(vector_db, 'vectors'):
                doc_count = len(vector_db.vectors)
            else:
                doc_count = 0
            logger.info(f"Created vector database with {doc_count} documents (took {time.monotonic() - embedding_start:.2f}s)")

            logger.info(f"Creating RAG pipeline for session {session_id}")
            rag_pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)

            # Replacing the session value with the pipeline marks it ready.
            session_manager.update_session(session_id, rag_pipeline)
            logger.info(f"Updated session {session_id} with processed pipeline (total time: {time.monotonic() - start_time:.2f}s)")
        except asyncio.TimeoutError:
            logger.error(f"Vector database creation timed out after {time.monotonic() - embedding_start:.2f}s")
            session_manager.update_session(session_id, "failed")
        except Exception as e:
            logger.error(f"Error in vector database creation: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
    except Exception as e:
        logger.error(f"Error in background processing for session {session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        # Mark the session as failed rather than removing it.
        session_manager.update_session(session_id, "failed")
    finally:
        # Single cleanup point: runs after success, failure and early return.
        _remove_temp_file(temp_path)
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    """Accept an upload, store it in a temp file and start background processing.

    The upload is streamed to disk in 1 MB chunks while its size is checked,
    so an oversized file is rejected without being held in memory.

    NOTE(review): no route decorator is visible on this handler in this view -
    confirm how it is registered on ``app``.

    Args:
        background_tasks: FastAPI task queue used to schedule
            ``process_file_background`` after the response is sent.
        file: The uploaded file.

    Returns:
        A dict with the new ``session_id`` and a status ``message``.

    Raises:
        HTTPException: 413 when the upload exceeds ``FILE_SIZE_LIMIT``;
            500 for any other failure while storing the file.
    """
    temp_path = None
    # Once the background task is scheduled it owns (and deletes) the temp file.
    handed_off = False
    try:
        logger.info(f"Received upload request for file: {file.filename}")

        chunk_size = 1024 * 1024  # 1MB chunks for reading
        file_size = 0
        # Keep the original extension on the temp file; tolerate a missing
        # filename or a name without any extension.
        suffix = pathlib.Path(file.filename or "").suffix

        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_path = temp_file.name
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                file_size += len(chunk)
                if file_size > FILE_SIZE_LIMIT:
                    logger.warning(f"File too large: {file_size/1024/1024:.2f}MB exceeds limit of {FILE_SIZE_LIMIT/1024/1024}MB")
                    # Must be raised: a *returned* HTTPException is treated as
                    # an ordinary response value, not as an error response.
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large. Maximum size is {FILE_SIZE_LIMIT/1024/1024}MB"
                    )
                temp_file.write(chunk)

        logger.info(f"File size: {file_size/1024/1024:.2f}MB")
        logger.info(f"Created temporary file at: {temp_path}")

        # Generate session ID and create session
        session_id = session_manager.create_session("processing")
        logger.info(f"Created session ID: {session_id}")

        # Process file in background
        background_tasks.add_task(
            process_file_background,
            temp_path,
            file.filename,
            session_id
        )
        handed_off = True

        return {"session_id": session_id, "message": "File uploaded and processing started"}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 413 above) must not become a 500.
        raise
    except Exception as e:
        logger.error(f"Error processing upload: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
    finally:
        # Do not leave an orphaned temp file behind when the request failed.
        if temp_path is not None and not handed_off:
            try:
                os.unlink(temp_path)
            except OSError as cleanup_error:
                logger.error(f"Error cleaning up temp file: {str(cleanup_error)}")
async def process_query(request: QueryRequest):
    """Answer a question using the pipeline stored for the request's session.

    NOTE(review): no route decorator is visible on this handler in this view -
    confirm how it is registered on ``app``.

    Args:
        request: Carries the ``session_id`` to use and the ``query`` text.

    Returns:
        A dict with the full ``response`` text and the ``session_id``.

    Raises:
        HTTPException: 404 if the session is unknown, 409 while the document
            is still processing, 500 if processing failed or the query errors.
    """
    sid = request.session_id
    logger.info(f"Received query request for session: {sid}")

    # Guard: the session must exist.
    if not session_manager.session_exists(sid):
        logger.warning(f"Session not found: {sid}")
        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")

    # The stored value is a status string or the ready pipeline object.
    pipeline = session_manager.get_session(sid)

    # Guard: background processing has not finished yet.
    if pipeline == "processing":
        logger.info(f"Document still processing for session: {sid}")
        raise HTTPException(status_code=409, detail="Document is still being processed. Please try again in a moment.")

    # Guard: background processing gave up.
    if pipeline == "failed":
        logger.error(f"Processing failed for session: {sid}")
        raise HTTPException(status_code=500, detail="Document processing failed. Please try uploading again.")

    try:
        logger.info(f"Processing query: '{request.query}' for session: {sid}")
        began = time.time()
        result = await pipeline.arun_pipeline(request.query)

        # The response arrives as an async stream; gather every piece so the
        # caller receives one complete string rather than a stream.
        pieces = [piece async for piece in result["response"]]
        response_text = "".join(pieces)

        logger.info(f"Generated response of length: {len(response_text)} (took {time.time() - began:.2f}s)")
        return {
            "response": response_text,
            "session_id": sid
        }
    except Exception as e:
        logger.error(f"Error processing query for session {sid}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
def health_check():
    """Return a fixed payload reporting the service as healthy."""
    payload = {"status": "healthy"}
    return payload
def test_endpoint():
    """Return a fixed message confirming the backend can be reached."""
    payload = {"message": "Backend is accessible"}
    return payload
async def session_status(session_id: str):
    """Report whether a session exists and how far its processing has got.

    Args:
        session_id: Identifier of the session to look up.

    Returns:
        A dict with ``exists`` (bool) and ``status``, one of ``"not_found"``,
        ``"processing"``, ``"failed"`` or ``"ready"``.
    """
    logger.info(f"Checking status for session: {session_id}")

    if not session_manager.session_exists(session_id):
        logger.warning(f"Session not found: {session_id}")
        return {"exists": False, "status": "not_found"}

    # Anything other than the two status strings counts as ready.
    state = session_manager.get_session(session_id)
    if state == "processing":
        logger.info(f"Session {session_id} is still processing")
        status = "processing"
    elif state == "failed":
        logger.error(f"Session {session_id} processing failed")
        status = "failed"
    else:
        logger.info(f"Session {session_id} is ready")
        status = "ready"
    return {"exists": True, "status": status}
async def debug_sessions():
    """Return the session manager's summary of all sessions (diagnostics only)."""
    logger.info("Accessed debug sessions endpoint")
    # Hand back the summary of all sessions as produced by the session manager.
    return session_manager.get_sessions_summary()
# For Hugging Face Spaces deployment the React build is served by this same
# process; the mount is skipped when the build directory does not exist.
frontend_path = pathlib.Path(__file__).parent.parent.joinpath("frontend", "build")
if frontend_path.exists():
    app.mount(
        "/",
        StaticFiles(directory=str(frontend_path), html=True),
        name="frontend",
    )

    # NOTE(review): no route decorator is visible on this handler, and its
    # nesting under the ``if`` is reconstructed from whitespace-mangled
    # source - confirm how (and whether) it is registered.
    async def serve_frontend():
        """Return the built frontend's ``index.html``."""
        return FileResponse(str(frontend_path / "index.html"))
# Start a uvicorn server on all interfaces, port 8000, when run as a script.
if __name__ == "__main__":
    uvicorn.run("main:app", port=8000, host="0.0.0.0")