# app.py - Updated with better request handling
import os
import json
import asyncio
import logging
import uuid
import re
from typing import Dict, List, Optional
from datetime import datetime, timedelta

from fastapi import FastAPI, HTTPException, Request, BackgroundTasks
from pydantic import BaseModel
from llama_cpp import Llama

# Correctly reference the module within the 'app' package
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated

# -----------------------------
# ✅ Logging Configuration
# -----------------------------
# NOTE: the format deliberately has no `%(request_id)s` field. Records from
# plain `logger.*` calls (startup, error paths) and from third-party libraries
# carry no `request_id` attribute; with that placeholder the formatter fails
# on every such record and the message is never written. Per-request context
# is added to the message text by RequestIdAdapter instead.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)


class RequestIdAdapter(logging.LoggerAdapter):
    """Logger adapter that prefixes each message with ``[<request_id>]``.

    Expects ``extra`` to contain a ``"request_id"`` key.
    """

    def process(self, msg, kwargs):
        return "[%s] %s" % (self.extra["request_id"], msg), kwargs


logger = logging.getLogger("app")


# -----------------------------
# ✅ Queue Management Classes
# -----------------------------
class QueuedRequest:
    """State for one chat request as it moves through ``RequestQueue``."""

    def __init__(self, request_id: str, question: str, timestamp: datetime):
        self.request_id = request_id
        self.question = question
        self.timestamp = timestamp  # time the request was enqueued
        # One of: queued, processing, completed, failed, timeout, cancelled
        self.status = "queued"
        self.result: Optional[Dict] = None
        self.error: Optional[str] = None
        self.cancelled = False  # Track if request was cancelled
        # Refreshed on status lookups and on completion; drives cleanup.
        self.last_accessed = datetime.now()


class RequestQueue:
    """FIFO queue feeding a single background worker.

    Holds at most ``max_size`` waiting requests, at most one in-flight
    request (``processing``), and finished requests (``completed_requests``)
    kept for result retrieval until they go unread for ``max_completed_age``
    seconds. All state is guarded by ``self.lock``.
    """

    # Heuristic minutes per request used to build the ETA strings.
    _MINUTES_PER_REQUEST = 2

    def __init__(self, max_size: int = 15):
        self.queue: List[QueuedRequest] = []
        self.processing: Optional[QueuedRequest] = None
        self.completed_requests: Dict[str, QueuedRequest] = {}
        self.max_size = max_size
        self.lock = asyncio.Lock()
        self.cleanup_interval = 300  # 5 minutes (NOTE: not currently read by this class)
        self.max_completed_age = 600  # 10 minutes

    def _estimate_wait(self, queue_index: int) -> str:
        """Return the ETA string for the request at 0-based ``queue_index``.

        Counts the in-flight request (if any) plus every request queued
        ahead of this one. Must be called with ``self.lock`` held.
        """
        ahead_of_you = (1 if self.processing else 0) + queue_index
        per_request = self._MINUTES_PER_REQUEST
        return f"{ahead_of_you * per_request}-{(ahead_of_you + 1) * per_request} minutes"

    async def add_request(self, request_id: str, question: str) -> Dict:
        """Enqueue a question.

        Returns:
            Dict with ``status`` (``"queued"`` or ``"queue_full"``),
            ``message``, ``queue_position`` and ``estimated_wait_time``
            (the last two are ``None`` when the queue is full).
        """
        async with self.lock:
            # Clean up old requests periodically
            await self._cleanup_old_requests()

            if len(self.queue) >= self.max_size:
                return {
                    "status": "queue_full",
                    "message": f"Queue is full (max {self.max_size} requests). Please try again later.",
                    "queue_position": None,
                    "estimated_wait_time": None,
                }

            # Always enqueue; the background worker is the single executor
            self.queue.append(QueuedRequest(request_id, question, datetime.now()))
            position = len(self.queue)  # 1-based position in queue

            message = (
                "Using free CPU tier - can only process one request at a time. "
                f"Your request is #{position} in queue and will be processed after current requests are completed."
            )
            return {
                "status": "queued",
                "message": message,
                "queue_position": position,
                "estimated_wait_time": self._estimate_wait(position - 1),
            }

    async def get_next_request(self) -> Optional[QueuedRequest]:
        """Pop the oldest queued request and mark it as processing (or return None)."""
        async with self.lock:
            if self.queue:
                next_request = self.queue.pop(0)
                self.processing = next_request
                next_request.status = "processing"
                return next_request
            return None

    async def complete_request(
        self,
        request_id: str,
        result: Optional[Dict] = None,
        error: Optional[str] = None,
    ):
        """Finish the in-flight request and move it to ``completed_requests``.

        Does nothing if ``request_id`` is not the request currently being
        processed. A cancelled request is stored as ``"cancelled"`` without
        its result; otherwise a non-None ``result`` marks it ``"completed"``
        and anything else marks it ``"failed"``.
        """
        async with self.lock:
            current = self.processing
            if current is None or current.request_id != request_id:
                return

            if current.cancelled:
                # Don't store results for cancelled requests
                current.status = "cancelled"
                logger.info(f"Request {request_id} was cancelled, not storing result")
            elif result is not None:
                current.result = result
                current.status = "completed"
            else:
                # Never leave a finished request reporting "processing".
                current.error = error or "Request finished without a result"
                current.status = "failed"

            # Restart the retention window so a client that waited a long
            # time in the queue still gets the full window to fetch the result.
            current.last_accessed = datetime.now()
            # Store completed request for result retrieval (even cancelled ones briefly)
            self.completed_requests[request_id] = current
            self.processing = None

    async def cancel_request(self, request_id: str) -> bool:
        """Cancel a request if it exists in queue or is processing.

        An in-flight request is only flagged (the worker checks the flag);
        a queued request is removed and recorded as cancelled.

        Returns:
            True if a matching request was found, False otherwise.
        """
        async with self.lock:
            # Check if it's currently processing
            if self.processing and self.processing.request_id == request_id:
                self.processing.cancelled = True
                logger.info(f"Marked processing request {request_id} as cancelled")
                return True

            # Check if it's in queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    cancelled_req = self.queue.pop(i)
                    cancelled_req.status = "cancelled"
                    cancelled_req.cancelled = True
                    self.completed_requests[request_id] = cancelled_req
                    logger.info(f"Cancelled queued request {request_id}")
                    return True

            return False

    async def get_request_status(self, request_id: str) -> Optional[Dict]:
        """Return a status dict for ``request_id``, or None if it is unknown.

        Also refreshes the request's ``last_accessed`` time.
        """
        async with self.lock:
            current_time = datetime.now()

            # Check if currently processing
            current = self.processing
            if current and current.request_id == request_id:
                current.last_accessed = current_time
                if current.cancelled:
                    return {
                        "status": "cancelled",
                        "message": "Request was cancelled.",
                        "result": None,
                        "error": "Request cancelled by user",
                    }
                return {
                    "status": current.status,
                    "message": "Your request is currently being processed.",
                    "result": current.result,
                }

            # Check completed requests
            req = self.completed_requests.get(request_id)
            if req is not None:
                req.last_accessed = current_time
                status_messages = {
                    "completed": "Request completed.",
                    "failed": "Request failed.",
                    "cancelled": "Request was cancelled.",
                    "timeout": "Request timed out.",
                }
                return {
                    "status": req.status,
                    "message": status_messages.get(req.status, "Request processed."),
                    "result": req.result,
                    "error": req.error,
                }

            # Check queue
            for i, req in enumerate(self.queue):
                if req.request_id == request_id:
                    req.last_accessed = current_time
                    return {
                        "status": "queued",
                        "message": f"Your request is #{i+1} in queue.",
                        "queue_position": i + 1,
                        # Same heuristic as add_request, so the ETA agrees.
                        "estimated_wait_time": self._estimate_wait(i),
                    }

            return None

    async def _cleanup_old_requests(self):
        """Drop finished requests not accessed within ``max_completed_age`` seconds.

        Only ``completed_requests`` is pruned; queued and in-flight requests
        are left alone. Must be called with ``self.lock`` held.
        """
        cutoff_time = datetime.now() - timedelta(seconds=self.max_completed_age)
        stale_ids = [
            request_id
            for request_id, req in self.completed_requests.items()
            if req.last_accessed < cutoff_time
        ]
        for request_id in stale_ids:
            del self.completed_requests[request_id]
            logger.info(f"Cleaned up old request: {request_id}")

    async def get_queue_info(self) -> Dict:
        """Return a snapshot of queue length, in-flight id, capacity and stored results."""
        async with self.lock:
            return {
                "queue_length": len(self.queue),
                "currently_processing": self.processing.request_id if self.processing else None,
                "max_queue_size": self.max_size,
                "completed_requests_count": len(self.completed_requests),
            }


# -----------------------------
# ✅ Configuration
# -----------------------------
DB_PERSIST_DIRECTORY = os.getenv("DB_PERSIST_DIRECTORY", "/app/vector_database")
CHUNKS_FILE_PATH = os.getenv("CHUNKS_FILE_PATH", "/app/granular_chunks_final.jsonl")
MODEL_PATH = os.getenv("MODEL_PATH", "/app/tinyllama_dop_q4_k_m.gguf")
LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT_SECONDS", "120"))
RELEVANCE_THRESHOLD = float(os.getenv("RELEVANCE_THRESHOLD", "0.3"))
TOP_K_SEARCH = int(os.getenv("TOP_K_SEARCH", "4"))
TOP_K_CONTEXT = int(os.getenv("TOP_K_CONTEXT", "2"))
MAX_QUEUE_SIZE = int(os.getenv("MAX_QUEUE_SIZE", "15"))

# -----------------------------
# ✅ Initialize FastAPI App
# -----------------------------
app = FastAPI(title="NEEPCO DoP RAG Chatbot", version="2.1.0")

# Initialize request queue
request_queue = RequestQueue(max_size=MAX_QUEUE_SIZE)


@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Attach a fresh UUID to every HTTP request and echo it in ``X-Request-ID``."""
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response


# -----------------------------
# ✅ Vector DB and Data Initialization
# -----------------------------
logger.info("Initializing vector DB...")
try:
    db = PolicyVectorDB(
        persist_directory=DB_PERSIST_DIRECTORY,
        top_k_default=TOP_K_SEARCH,
        relevance_threshold=RELEVANCE_THRESHOLD,
    )
    if not ensure_db_populated(db, CHUNKS_FILE_PATH):
        logger.warning("DB not populated on startup. RAG will not function correctly.")
        db_ready = False
    else:
        logger.info("Vector DB is populated and ready.")
        db_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to initialize Vector DB: {e}", exc_info=True)
    db = None
    db_ready = False

# -----------------------------
# ✅ Load TinyLlama GGUF Model
# -----------------------------
logger.info(f"Loading GGUF model from: {MODEL_PATH}")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,
        n_threads=1,
        n_batch=512,
        use_mlock=True,
        verbose=False,
    )
    logger.info("GGUF model loaded successfully.")
    model_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to load GGUF model: {e}", exc_info=True)
    llm = None
    model_ready = False

# The shared Llama instance must not run two generations at once. When
# asyncio.wait_for() times out, the executor thread doing the old generation
# keeps running (threads cannot be cancelled), so on the default multi-thread
# executor the next queued request would start a second, concurrent
# generation on the same model. A single-worker executor serializes them.
from concurrent.futures import ThreadPoolExecutor  # noqa: E402  (only needed from here on)

_llm_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="llm")


# -----------------------------
# ✅ API Schemas
# -----------------------------
class Query(BaseModel):
    question: str


class Feedback(BaseModel):
    request_id: str
    question: str
    answer: str
    context_used: str
    feedback: str
    comment: str | None = None


# -----------------------------
# ✅ Background Processing
# -----------------------------
async def process_queued_requests():
    """Background task to process queued requests, one at a time, forever.

    Polls ``request_queue`` every 2 s when idle. Per-request failures are
    recorded on the request; unexpected worker errors are logged and the
    loop backs off for 5 s instead of dying.
    """
    while True:
        try:
            next_request = await request_queue.get_next_request()
            if not next_request:
                # No requests in queue, wait a bit before checking again
                await asyncio.sleep(2)
                continue

            logger.info(f"Processing queued request: {next_request.request_id}")
            try:
                # Check if request was cancelled before processing
                if next_request.cancelled:
                    logger.info(f"Skipping cancelled request: {next_request.request_id}")
                    await request_queue.complete_request(
                        next_request.request_id, error="Request was cancelled"
                    )
                    continue

                result = await process_chat_request(next_request.question, next_request.request_id)

                # Check again if request was cancelled during processing
                if next_request.cancelled:
                    logger.info(f"Request was cancelled during processing: {next_request.request_id}")
                    await request_queue.complete_request(
                        next_request.request_id, error="Request was cancelled during processing"
                    )
                else:
                    await request_queue.complete_request(next_request.request_id, result=result)
                    logger.info(f"Completed request: {next_request.request_id}")
            except Exception as e:
                error_msg = f"Error processing request: {str(e)}"
                logger.error(f"Failed to process request {next_request.request_id}: {e}", exc_info=True)
                await request_queue.complete_request(next_request.request_id, error=error_msg)
        except Exception as e:
            logger.error(f"Error in background processor: {e}", exc_info=True)
            await asyncio.sleep(5)


# Start background processor
@app.on_event("startup")
async def startup_event():
    # Keep a strong reference: the event loop only holds weak references to
    # tasks, so an unreferenced task may be garbage-collected mid-execution.
    app.state.queue_worker = asyncio.create_task(process_queued_requests())


@app.on_event("shutdown")
async def shutdown_event():
    """Stop the background queue worker started in ``startup_event``."""
    worker = getattr(app.state, "queue_worker", None)
    if worker is not None:
        worker.cancel()


# -----------------------------
# ✅ Core Processing Function
# -----------------------------
def _keyword_terms(text: str) -> set:
    """Return lower-cased word tokens longer than 3 characters.

    Tokenizes on word characters so punctuation (e.g. a trailing '?') does
    not prevent a question term from matching the same word in a chunk.
    """
    return {term for term in re.findall(r"\w+", text.lower()) if len(term) > 3}


# ✅ Re-ranking function for improving relevance
def re_rank_by_relevance(results: List[Dict], question: str) -> List[Dict]:
    """Simple heuristic re-ranking based on question keyword overlap.

    Multiplies each result's ``relevance_score`` in place by
    ``1 + 0.15 * overlap``, where overlap is the fraction of question terms
    present in the result's ``text``, then returns the results sorted by
    score, highest first.
    """
    question_terms = _keyword_terms(question)
    for result in results:
        if question_terms:
            chunk_terms = _keyword_terms(result["text"])
            keyword_overlap = len(question_terms & chunk_terms) / len(question_terms)
        else:
            keyword_overlap = 0
        # Boost score if chunk contains question keywords
        result["relevance_score"] *= 1 + 0.15 * keyword_overlap
    return sorted(results, key=lambda x: x["relevance_score"], reverse=True)


def get_logger_adapter(request_id: str):
    """Return a logger adapter that tags messages with ``request_id``."""
    return RequestIdAdapter(logger, {"request_id": request_id})


async def generate_llm_response(prompt: str, request_id: str):
    """Run the model on ``prompt`` off the event loop and return the stripped text.

    Raises:
        ValueError: if the model returns an empty completion.
    """
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        _llm_executor,
        lambda: llm(
            prompt,
            max_tokens=512,  # Optimized for CPU performance
            # An empty string must never be a stop sequence: "" is a substring
            # of every text, so it would end generation immediately with an
            # empty completion. "</s>" is presumably the intended TinyLlama
            # end-of-turn marker.
            stop=["###", "Question:", "Context:", "</s>"],
            temperature=0.1,  # Lower for factuality
            top_p=0.9,  # Nucleus sampling for consistency
            echo=False,
        ),
    )
    answer = response["choices"][0]["text"].strip()
    if not answer:
        raise ValueError("Empty response from LLM")
    return answer


async def process_chat_request(question: str, request_id: str) -> Dict:
    """Answer one question with retrieval-augmented generation.

    Short greetings get a canned introduction. Otherwise the vector DB is
    searched, results are re-ranked, the top ``TOP_K_CONTEXT`` chunks form
    the prompt context, and the LLM answer is returned (pipe-separated
    answers are reformatted as ``* item`` lines).

    Returns:
        Dict with ``request_id``, ``question``, ``context_used`` and
        ``answer``. LLM timeouts/errors are reported in ``answer``.

    Raises:
        HTTPException: 503 if the vector DB or the model failed to initialize.
    """
    adapter = get_logger_adapter(request_id)
    question_lower = question.strip().lower()

    # --- GREETING & INTRO HANDLING ---
    # Trailing punctuation is ignored so "Hi!" or "Who are you?" also match.
    greeting_keywords = {"hello", "hi", "hey", "what can you do", "who are you"}
    if question_lower.rstrip("!?. ") in greeting_keywords:
        adapter.info(f"Handling a greeting or introductory query: '{question}'")
        intro_message = (
            "Hello! I am an AI assistant specifically trained on NEEPCO's Delegation of Powers (DoP) policy document. "
            "My purpose is to help you find accurate information and answer questions based on this specific dataset. "
            "I am currently running on a CPU-based environment. How can I assist you with the DoP policy today?"
        )
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "NA - Greeting",
            "answer": intro_message,
        }

    if not db_ready or not model_ready:
        adapter.error("Service unavailable due to initialization failure.")
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    adapter.info(f"Received query: '{question}'")

    # 1. Search the vector DB.
    # NOTE(review): db.search runs synchronously on the event-loop thread, so
    # status polling stalls while it runs; consider asyncio.to_thread once
    # PolicyVectorDB is confirmed to be thread-safe.
    search_results = db.search(question, top_k=TOP_K_SEARCH)

    # 2. Re-rank results by keyword overlap for better relevance
    search_results = re_rank_by_relevance(search_results, question)

    if not search_results:
        adapter.warning("No relevant context found in vector DB.")
        return {
            "request_id": request_id,
            "question": question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing.",
        }

    scores = [f"{result['relevance_score']:.4f}" for result in search_results]
    adapter.info(f"Found {len(search_results)} relevant chunks with scores: {scores}")

    # 3. Prepare Context
    context_chunks = [result["text"] for result in search_results[:TOP_K_CONTEXT]]
    context = "\n---\n".join(context_chunks)

    # 4. Build Enhanced Prompt
    # NOTE(review): the TinyLlama chat template normally closes each turn with
    # "</s>"; confirm whether the fine-tuned model expects those markers here.
    prompt = f"""<|system|>
You are NEEPCO's Delegation of Powers (DoP) policy expert. Answer ONLY using the provided context.
- Be concise and factual
- For lists/steps, use pipe separators: `Item1|Item2|Item3`
- If information is absent, say: "The provided policy context does not contain information on this topic."
- Do not assume or infer beyond what is stated
<|user|>
### Context:
{context}

### Question:
{question}

Answer based strictly on the context above.
<|assistant|>
"""

    # 5. Generate Response
    answer = "An error occurred while processing your request."
    try:
        adapter.info("Sending prompt to LLM for generation...")
        raw_answer = await asyncio.wait_for(
            generate_llm_response(prompt, request_id),
            timeout=LLM_TIMEOUT_SECONDS,
        )
        adapter.info(f"LLM generation successful. Raw response: {raw_answer[:250]}...")

        # --- POST-PROCESSING LOGIC ---
        # Check if the model used the pipe separator, indicating a list.
        if "|" in raw_answer:
            adapter.info("Pipe separator found. Formatting response as a bulleted list.")
            # One "* item" line per non-empty pipe-separated item.
            answer = "\n".join(
                f"* {item.strip()}" for item in raw_answer.split("|") if item.strip()
            )
        else:
            # If no separator, use the answer as is.
            answer = raw_answer
    except asyncio.TimeoutError:
        adapter.warning(f"LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.")
        answer = "Sorry, the request took too long to process. Please try again with a simpler question."
    except Exception as e:
        adapter.error(f"An unexpected error occurred during LLM generation: {e}", exc_info=True)
        answer = "Sorry, an unexpected error occurred while generating a response."

    adapter.info("Final answer prepared. Returning result.")
    return {
        "request_id": request_id,
        "question": question,
        "context_used": context,
        "answer": answer,
    }


# -----------------------------
# ✅ Endpoints
# -----------------------------
@app.get("/")
async def root():
    return {"status": "✅ Server is running."}


@app.get("/health")
async def health_check():
    """Report component readiness and queue state; 503 if DB or model is not ready."""
    queue_info = await request_queue.get_queue_info()
    status = {
        "status": "ok",
        "database_status": "ready" if db_ready else "error",
        "model_status": "ready" if model_ready else "error",
        "queue_info": queue_info,
    }
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail=status)
    return status


@app.post("/chat")
async def chat(query: Query, request: Request):
    """Add request to queue and return queue status"""
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    request_id = request.state.request_id
    adapter = get_logger_adapter(request_id)
    adapter.info(f"Received chat request: '{query.question}'")

    # Add request to queue
    queue_status = await request_queue.add_request(request_id, query.question)
    return {
        "request_id": request_id,
        "question": query.question,
        **queue_status,
    }


@app.get("/status/{request_id}")
async def get_request_status(request_id: str):
    """Check the status of a specific request (404 if unknown)."""
    try:
        status = await request_queue.get_request_status(request_id)
    except Exception as e:
        logger.error(f"Error checking status for {request_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error checking request status")

    # Raised outside the try so the 404 is not swallowed and turned into a 500.
    if not status:
        raise HTTPException(status_code=404, detail="Request not found")

    return {"request_id": request_id, **status}


@app.delete("/cancel/{request_id}")
async def cancel_request(request_id: str):
    """Cancel a specific request (404 if unknown)."""
    try:
        cancelled = await request_queue.cancel_request(request_id)
    except Exception as e:
        logger.error(f"Error cancelling request {request_id}: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="Error cancelling request")

    if not cancelled:
        raise HTTPException(status_code=404, detail="Request not found or cannot be cancelled")

    return {
        "status": "cancelled",
        "message": f"Request {request_id} has been cancelled",
        "request_id": request_id,
    }


@app.get("/queue")
async def get_queue_status():
    """Get current queue information"""
    return await request_queue.get_queue_info()


@app.post("/feedback")
async def collect_feedback(feedback: Feedback, request: Request):
    """Log user feedback as a single JSON line tagged with this request's id."""
    adapter = get_logger_adapter(request.state.request_id)
    feedback_log = {
        "type": "USER_FEEDBACK",
        "request_id": feedback.request_id,
        "question": feedback.question,
        "answer": feedback.answer,
        "context_used": feedback.context_used,
        "feedback": feedback.feedback,
        "comment": feedback.comment,
    }
    adapter.info(json.dumps(feedback_log))
    return {"status": "✅ Feedback recorded. Thank you!"}