# app/app.py: FastAPI entry point for the NEEPCO DoP RAG chatbot.
import os
import json
import asyncio
import logging
import uuid
import re
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from llama_cpp import Llama
# Correctly reference the module within the 'app' package
from app.policy_vector_db import PolicyVectorDB, ensure_db_populated
# -----------------------------
# ✅ Logging Configuration
# -----------------------------
# Note: RequestIdAdapter prefixes the request id onto the message itself, so the base
# format string must not reference %(request_id)s (plain module-level logger calls would
# otherwise fail to format).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
class RequestIdAdapter(logging.LoggerAdapter):
    def process(self, msg, kwargs):
        return '[%s] %s' % (self.extra['request_id'], msg), kwargs
logger = logging.getLogger("app")
# -----------------------------
# ✅ Configuration
# -----------------------------
DB_PERSIST_DIRECTORY = os.getenv("DB_PERSIST_DIRECTORY", "/app/vector_database")
CHUNKS_FILE_PATH = os.getenv("CHUNKS_FILE_PATH", "/app/granular_chunks_final.jsonl")
MODEL_PATH = os.getenv("MODEL_PATH", "/app/tinyllama_dop_q4_k_m.gguf")
LLM_TIMEOUT_SECONDS = int(os.getenv("LLM_TIMEOUT_SECONDS", "90"))
RELEVANCE_THRESHOLD = float(os.getenv("RELEVANCE_THRESHOLD", "0.3"))
TOP_K_SEARCH = int(os.getenv("TOP_K_SEARCH", "3"))
TOP_K_CONTEXT = int(os.getenv("TOP_K_CONTEXT", "1"))
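# Every setting above can be overridden through environment variables without code changes,
# for example when running in a container (the image name below is hypothetical):
#   docker run -e LLM_TIMEOUT_SECONDS=120 -e TOP_K_SEARCH=5 <your-image>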
# -----------------------------
# ✅ Initialize FastAPI App
# -----------------------------
app = FastAPI(title="NEEPCO DoP RAG Chatbot", version="2.1.0")
@app.middleware("http")
async def add_request_id(request: Request, call_next):
    request_id = str(uuid.uuid4())
    request.state.request_id = request_id
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
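# The X-Request-ID header lets a caller quote the exact request when reporting a problem.
# Example (assuming a local uvicorn run on port 8000):
#   curl -si http://localhost:8000/health | grep -i x-request-id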
# -----------------------------
# ✅ Vector DB and Data Initialization
# -----------------------------
logger.info("Initializing vector DB...")
try:
    db = PolicyVectorDB(
        persist_directory=DB_PERSIST_DIRECTORY,
        top_k_default=TOP_K_SEARCH,
        relevance_threshold=RELEVANCE_THRESHOLD
    )
    if not ensure_db_populated(db, CHUNKS_FILE_PATH):
        logger.warning("DB not populated on startup. RAG will not function correctly.")
        db_ready = False
    else:
        logger.info("Vector DB is populated and ready.")
        db_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to initialize Vector DB: {e}", exc_info=True)
    db = None
    db_ready = False
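# db_ready / model_ready are module-level flags: /health reports them and /chat refuses to
# answer unless both are True, so a failed startup degrades to 503 responses instead of
# crashing the worker.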
# -----------------------------
# ✅ Load TinyLlama GGUF Model
# -----------------------------
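# The model is loaded once at import time. n_ctx caps prompt + completion tokens, and
# use_mlock pins the weights in RAM to avoid page-outs on the CPU-only host this Space is
# assumed to run on.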
logger.info(f"Loading GGUF model from: {MODEL_PATH}")
try:
    llm = Llama(
        model_path=MODEL_PATH,
        n_ctx=2048,
        n_threads=1,
        n_batch=512,
        use_mlock=True,
        verbose=False
    )
    logger.info("GGUF model loaded successfully.")
    model_ready = True
except Exception as e:
    logger.error(f"FATAL: Failed to load GGUF model: {e}", exc_info=True)
    llm = None
    model_ready = False
# -----------------------------
# ✅ API Schemas
# -----------------------------
class Query(BaseModel):
    question: str
class Feedback(BaseModel):
    request_id: str
    question: str
    answer: str
    context_used: str
    feedback: str
    comment: str | None = None
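# Illustrative request bodies (values are made up):
#   POST /chat      {"question": "Who can approve a works contract?"}
#   POST /feedback  {"request_id": "<uuid from /chat>", "question": "...", "answer": "...",
#                    "context_used": "...", "feedback": "helpful", "comment": null}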
# -----------------------------
# ✅ Endpoints
# -----------------------------
def get_logger_adapter(request: Request):
    return RequestIdAdapter(logger, {'request_id': getattr(request.state, 'request_id', 'N/A')})
@app.get("/")
async def root():
return {"status": "βœ… Server is running."}
@app.get("/health")
async def health_check():
    status = {
        "status": "ok",
        "database_status": "ready" if db_ready else "error",
        "model_status": "ready" if model_ready else "error"
    }
    if not db_ready or not model_ready:
        raise HTTPException(status_code=503, detail=status)
    return status
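# llama_cpp's Llama.__call__ blocks the calling thread, so generation is pushed onto the
# default thread-pool executor; the event loop stays free to serve /health in the meantime.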
async def generate_llm_response(prompt: str, request_id: str):
    loop = asyncio.get_running_loop()
    response = await loop.run_in_executor(
        None,
        lambda: llm(prompt, max_tokens=1024, stop=["###", "Question:", "Context:", "</s>"], temperature=0.05, echo=False)
    )
    answer = response["choices"][0]["text"].strip()
    if not answer:
        raise ValueError("Empty response from LLM")
    return answer
@app.post("/chat")
async def chat(query: Query, request: Request):
    adapter = get_logger_adapter(request)
    question_lower = query.question.strip().lower()
    # --- GREETING & INTRO HANDLING ---
    greeting_keywords = ["hello", "hi", "hey", "what can you do", "who are you"]
    if question_lower in greeting_keywords:
        adapter.info(f"Handling a greeting or introductory query: '{query.question}'")
        intro_message = (
            "Hello! I am an AI assistant specifically trained on NEEPCO's Delegation of Powers (DoP) policy document. "
            "My purpose is to help you find accurate information and answer questions based on this specific dataset. "
            "I am currently running on a CPU-based environment. How can I assist you with the DoP policy today?"
        )
        return {
            "request_id": getattr(request.state, 'request_id', 'N/A'),
            "question": query.question,
            "context_used": "NA - Greeting",
            "answer": intro_message
        }
    if not db_ready or not model_ready:
        adapter.error("Service unavailable due to initialization failure.")
        raise HTTPException(status_code=503, detail="Service is not ready. Please check logs.")
    adapter.info(f"Received query: '{query.question}'")
    # 1. Search Vector DB
    search_results = db.search(query.question, top_k=TOP_K_SEARCH)
    if not search_results:
        adapter.warning("No relevant context found in vector DB.")
        return {
            "question": query.question,
            "context_used": "No relevant context found.",
            "answer": "Sorry, I could not find a relevant policy to answer that question. Please try rephrasing."
        }
    scores = [f"{result['relevance_score']:.4f}" for result in search_results]
    adapter.info(f"Found {len(search_results)} relevant chunks with scores: {scores}")
    # 2. Prepare Context
    context_chunks = [result['text'] for result in search_results[:TOP_K_CONTEXT]]
    context = "\n---\n".join(context_chunks)
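    # Note: TOP_K_SEARCH chunks are retrieved (and their scores logged), but only the top
    # TOP_K_CONTEXT of them are stitched into the prompt (defaults: 3 retrieved, 1 used).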
    # 3. Build Prompt with Separator Instruction
    prompt = f"""<|system|>
You are a precise and factual assistant for NEEPCO's Delegation of Powers (DoP) policy.
Your task is to answer the user's question based ONLY on the provided context.
- **Formatting Rule:** If the answer contains a list of items or steps, you **MUST** separate each item with a pipe symbol (`|`). For example: `First item|Second item|Third item`.
- **Content Rule:** If the information is not in the provided context, you **MUST** reply with the exact phrase: "The provided policy context does not contain information on this topic."
</s>
<|user|>
### Relevant Context:
```
{context}
```
### Question:
{query.question}
</s>
<|assistant|>
### Detailed Answer:
"""
    # 4. Generate Response
    answer = "An error occurred while processing your request."
    try:
        adapter.info("Sending prompt to LLM for generation...")
        raw_answer = await asyncio.wait_for(
            generate_llm_response(prompt, request.state.request_id),
            timeout=LLM_TIMEOUT_SECONDS
        )
        adapter.info(f"LLM generation successful. Raw response: {raw_answer[:250]}...")
        # --- POST-PROCESSING LOGIC ---
        # Check if the model used the pipe separator, indicating a list.
        if '|' in raw_answer:
            adapter.info("Pipe separator found. Formatting response as a bulleted list.")
            # Split the string into a list of items
            items = raw_answer.split('|')
            # Clean up each item and format it as a bullet point
            cleaned_items = [f"* {item.strip()}" for item in items if item.strip()]
            # Join them back together with newlines
            answer = "\n".join(cleaned_items)
        else:
            # If no separator, use the answer as is.
            answer = raw_answer
    except asyncio.TimeoutError:
        adapter.warning(f"LLM generation timed out after {LLM_TIMEOUT_SECONDS} seconds.")
        answer = "Sorry, the request took too long to process. Please try again with a simpler question."
    except Exception as e:
        adapter.error(f"An unexpected error occurred during LLM generation: {e}", exc_info=True)
        answer = "Sorry, an unexpected error occurred while generating a response."
    adapter.info("Final answer prepared. Returning to client.")
    return {
        "request_id": request.state.request_id,
        "question": query.question,
        "context_used": context,
        "answer": answer
    }
@app.post("/feedback")
async def collect_feedback(feedback: Feedback, request: Request):
    adapter = get_logger_adapter(request)
    feedback_log = {
        "type": "USER_FEEDBACK",
        "request_id": feedback.request_id,
        "question": feedback.question,
        "answer": feedback.answer,
        "context_used": feedback.context_used,
        "feedback": feedback.feedback,
        "comment": feedback.comment
    }
    adapter.info(json.dumps(feedback_log))
    return {"status": "✅ Feedback recorded. Thank you!"}