import os
import json
import asyncio
import logging
import uuid

from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from llama_cpp import Llama

# Correctly reference the module within the 'app' package
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated
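# NOTE: the PolicyVectorDB interface assumed throughout this file is inferred
# from how it is used below (see app/policy_vector_db.py for the actual API):
#   PolicyVectorDB(persist_directory, top_k_default, relevance_threshold)
#   db.search(question, top_k) -> list of {"text": str, "relevance_score": float}
#   ensure_db_populated(db, chunks_file_path) -> bool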
# -----------------------------
# Logging Configuration
# -----------------------------
# The adapter below injects the request ID into the message itself, so the
# format string does not reference %(request_id)s (module-level log calls made
# outside a request would otherwise fail to format).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')

class RequestIdAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        # Prefix every message with the request ID stored in `extra`.
        return '[%s] %s' % (self.extra['request_id'], msg), kwargs

logger = logging.getLogger("app")
# -----------------------------
# Configuration
# -----------------------------
DB_PERSIST_DIRECTORY = os.getenv("DB_PERSIST_DIRECTORY", "/app/vector_database")
CHUNKS_FILE_PATH = os.getenv("CHUNKS_FILE_PATH", "/app/granular_chunks_final.jsonl")
MODEL_PATH = os.getenv("MODEL_PATH", "/app/tinyllama_dop_q4_k_m.gguf")
LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT_SECONDS", "90"))
RELEVANCE_THRESHOLD = float(os.getenv("RELEVANCE_THRESHOLD", "0.3"))
TOP_K_SEARCH = int(os.getenv("TOP_K_SEARCH", "3"))
TOP_K_CONTEXT = int(os.getenv("TOP_K_CONTEXT", "1"))
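# Each setting above can be overridden via the environment, e.g. (illustrative
# shell snippet, values are examples only):
#   export LLM_TIMEOUT_SECONDS=120
#   export TOP_K_SEARCH=5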
# -----------------------------
# Initialize FastAPI App
# -----------------------------
app = FastAPI(title="NEEPCO DoP RAG Chatbot", version="2.1.0")

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    # Attach a unique ID to every request and echo it in the response headers.
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
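# The ID set above appears in every log line (via RequestIdAdapter) and in the
# X-Request-ID response header, so a client can quote it back in the /feedback
# payload to correlate its feedback with server-side logs.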
# -----------------------------
# Vector DB and Data Initialization
# -----------------------------
logger.info("Initializing vector DB...")
try:
    db = PolicyVectorDB(
        persist_directory=DB_PERSIST_DIRECTORY,
        top_k_default=TOP_K_SEARCH,
        relevance_threshold=RELEVANCE_THRESHOLD
    )
    if not ensure_db_populated(db, CHUNKS_FILE_PATH):
        logger.warning("DB not populated on startup. RAG will not function correctly.")
        db_ready = False
    else:
        logger.info("Vector DB is populated and ready.")
        db_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to initialize Vector DB: {e}", exc_info=True)
    db = None
    db_ready = False
# -----------------------------
# Load TinyLlama GGUF Model
# -----------------------------
logger.info(f"Loading GGUF model from: {MODEL_PATH}")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,
        n_threads=1,
        n_batch=512,
        use_mlock=True,
        verbose=False
    )
    logger.info("GGUF model loaded successfully.")
    model_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to load GGUF model: {e}", exc_info=True)
    llm = None
    model_ready = False
# -----------------------------
# API Schemas
# -----------------------------
class Query(BaseModel):
    question: str

class Feedback(BaseModel):
    request_id: str
    question: str
    answer: str
    context_used: str
    feedback: str
    comment: str | None = None
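# Example /feedback request body (illustrative values only):
# {
#   "request_id": "<uuid from the X-Request-ID header>",
#   "question": "Who approves X?",
#   "answer": "...",
#   "context_used": "...",
#   "feedback": "thumbs_up",
#   "comment": null
# }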
# -----------------------------
# Endpoints
# -----------------------------
def get_logger_adapter(request: Request):
    return RequestIdAdapter(logger, {'request_id': getattr(request.state, 'request_id', 'N/A')})

@app.get("/")
async def root():
    return {"status": "Server is running."}

@app.get("/health")
async def health_check():
    status = {
        "status": "ok",
        "database_status": "ready" if db_ready else "error",
        "model_status": "ready" if model_ready else "error"
    }
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail=status)
    return status
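# The 503 above makes /health usable as a readiness probe; e.g. a Docker
# HEALTHCHECK could poll it (sketch, assuming the app listens on port 8000 and
# curl is available in the image):
#   HEALTHCHECK CMD curl -f http://localhost:8000/health || exit 1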
async def generate_llm_response(prompt: str, request_id: str):
    # llama_cpp's call is blocking; run it in the default thread pool so the
    # event loop stays free to serve other requests.
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: llm(prompt, max_tokens=1024, stop=["###", "Question:", "Context:", "</s>"], temperature=0.05, echo=False)
    )
    answer = response["choices"][0]["text"].strip()
    if not answer:
        raise ValueError("Empty response from LLM")
    return answer
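# NOTE (assumption): a single Llama instance is generally not safe to call from
# several threads at once, and the default executor may run overlapping /chat
# requests concurrently. One sketch for serializing access would be a
# module-level asyncio.Lock around the call:
#   llm_lock = asyncio.Lock()
#   ...
#   async with llm_lock:
#       raw_answer = await generate_llm_response(prompt, request_id)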
@app.post("/chat")
async def chat(query: Query, request: Request):
    adapter = get_logger_adapter(request)
    question_lower = query.question.strip().lower()

    # --- GREETING & INTRO HANDLING ---
    greeting_keywords = ["hello", "hi", "hey", "what can you do", "who are you"]
    if question_lower in greeting_keywords:
        adapter.info(f"Handling a greeting or introductory query: '{query.question}'")
        intro_message = (
            "Hello! I am an AI assistant specifically trained on NEEPCO's Delegation of Powers (DoP) policy document. "
            "My purpose is to help you find accurate information and answer questions based on this specific dataset. "
            "I am currently running on a CPU-based environment. How can I assist you with the DoP policy today?"
        )
        return {
            "request_id": getattr(request.state, 'request_id', 'N/A'),
            "question": query.question,
            "context_used": "NA - Greeting",
            "answer": intro_message
        }

    if not db_ready or not model_ready:
        adapter.error("Service unavailable due to initialization failure.")
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")

    adapter.info(f"Received query: '{query.question}'")

    # 1. Search Vector DB
    search_results = db.search(query.question, top_k=TOP_K_SEARCH)
    if not search_results:
        adapter.warning("No relevant context found in vector DB.")
        return {
            "request_id": getattr(request.state, 'request_id', 'N/A'),
            "question": query.question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing."
        }

    scores = [f"{result['relevance_score']:.4f}" for result in search_results]
    adapter.info(f"Found {len(search_results)} relevant chunks with scores: {scores}")

    # 2. Prepare Context
    context_chunks = [result['text'] for result in search_results[:TOP_K_CONTEXT]]
    context = "\n---\n".join(context_chunks)
| # 3. Build Prompt with Separator Instruction | |
| prompt = f"""<|system|> | |
| You are a precise and factual assistant for NEEPCO's Delegation of Powers (DoP) policy. | |
| Your task is to answer the user's question based ONLY on the provided context. | |
| - **Formatting Rule:** If the answer contains a list of items or steps, you **MUST** separate each item with a pipe symbol (`|`). For example: `First item|Second item|Third item`. | |
| - **Content Rule:** If the information is not in the provided context, you **MUST** reply with the exact phrase: "The provided policy context does not contain information on this topic." | |
| </s> | |
| <|user|> | |
| ### Relevant Context: | |
| ``` | |
| {context} | |
| ``` | |
| ### Question: | |
| {query.question} | |
| </s> | |
| <|assistant|> | |
| ### Detailed Answer: | |
| """ | |
    # 4. Generate Response
    answer = "An error occurred while processing your request."
    try:
        adapter.info("Sending prompt to LLM for generation...")
        raw_answer = await asyncio.wait_for(
            generate_llm_response(prompt, request.state.request_id),
            timeout=LLM_TIMEOUT_SECONDS
        )
        adapter.info(f"LLM generation successful. Raw response: {raw_answer[:250]}...")

        # --- POST-PROCESSING LOGIC ---
        # Check if the model used the pipe separator, indicating a list.
        if '|' in raw_answer:
            adapter.info("Pipe separator found. Formatting response as a bulleted list.")
            # Split the string into a list of items
            items = raw_answer.split('|')
            # Clean up each item and format it as a bullet point
            cleaned_items = [f"* {item.strip()}" for item in items if item.strip()]
            # Join them back together with newlines
            answer = "\n".join(cleaned_items)
        else:
            # If no separator, use the answer as is.
            answer = raw_answer
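        # For example, a raw answer of "File note|Obtain approval|Issue order"
        # becomes:
        #   * File note
        #   * Obtain approval
        #   * Issue order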
    except asyncio.TimeoutError:
        adapter.warning(f"LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.")
        answer = "Sorry, the request took too long to process. Please try again with a simpler question."
    except Exception as e:
        adapter.error(f"An unexpected error occurred during LLM generation: {e}", exc_info=True)
        answer = "Sorry, an unexpected error occurred while generating a response."

    adapter.info("Final answer prepared. Returning to client.")
    return {
        "request_id": request.state.request_id,
        "question": query.question,
        "context_used": context,
        "answer": answer
    }
@app.post("/feedback")
async def collect_feedback(feedback: Feedback, request: Request):
    adapter = get_logger_adapter(request)
    # Emit the feedback as a single structured JSON log line so it can be
    # harvested from the logs later.
    feedback_log = {
        "type": "USER_FEEDBACK",
        "request_id": feedback.request_id,
        "question": feedback.question,
        "answer": feedback.answer,
        "context_used": feedback.context_used,
        "feedback": feedback.feedback,
        "comment": feedback.comment
    }
    adapter.info(json.dumps(feedback_log))
    return {"status": "Feedback recorded. Thank you!"}